mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2024-08-29 01:18:33 +03:00
128 lines
4.3 KiB
Python
128 lines
4.3 KiB
Python
# Copyright 2023 https://github.com/ShishirPatil/gorilla
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# This file is modified from https://github.com/ShishirPatil/gorilla/blob/main/eval/eval-scripts/ast_eval_hf.py
|
|
|
|
from tree_sitter import Language, Parser
|
|
|
|
|
|
# Get all the subtrees given a root_node
|
|
def get_all_sub_trees(root_node):
|
|
node_stack = []
|
|
sub_tree_sexp_list = []
|
|
depth = 1
|
|
# text = root_node.text
|
|
node_stack.append([root_node, depth])
|
|
while len(node_stack) != 0:
|
|
cur_node, cur_depth = node_stack.pop()
|
|
if cur_node.child_count > 0:
|
|
sub_tree_sexp_list.append(
|
|
[cur_node.sexp(), cur_depth, cur_node, cur_node.children[0].text]
|
|
)
|
|
else:
|
|
sub_tree_sexp_list.append([cur_node.sexp(), cur_depth, cur_node, None])
|
|
for child_node in cur_node.children:
|
|
if len(child_node.children) != 0:
|
|
depth = cur_depth + 1
|
|
node_stack.append([child_node, depth])
|
|
return sub_tree_sexp_list
|
|
|
|
|
|
# Parse the program into AST trees
|
|
def ast_parse(candidate, lang='python'):
|
|
LANGUAGE = Language('evaluation/gorilla/my-languages.so', lang)
|
|
parser = Parser()
|
|
parser.set_language(LANGUAGE)
|
|
|
|
candidate_tree = parser.parse(bytes(candidate, 'utf8')).root_node
|
|
return candidate_tree
|
|
|
|
|
|
# Get all the arguments in the ast tree
|
|
def get_args(node):
|
|
if node.child_count == 0:
|
|
return []
|
|
args_list = []
|
|
for child in node.children[0].children[0].children[1].children:
|
|
if '=' in child.text.decode():
|
|
args_list.append(child.children[2].text)
|
|
elif (
|
|
child.text.decode() != '('
|
|
and child.text.decode() != ')'
|
|
and child.text.decode() != ','
|
|
):
|
|
args_list.append(child.text)
|
|
return args_list
|
|
|
|
|
|
# Check if there is an api match
|
|
def ast_check(candidate_subtree_list, base_tree_list):
|
|
for idx, base_tree in enumerate(base_tree_list):
|
|
if base_tree.children[0].children[0].child_count == 0:
|
|
continue
|
|
api_name = base_tree.children[0].children[0].children[0].text
|
|
for candidate_tree in candidate_subtree_list:
|
|
if candidate_tree[3] == api_name:
|
|
break
|
|
# Now we have a sub-tree
|
|
candidate_tree = candidate_tree[2]
|
|
args_list = get_args(base_tree)
|
|
if len(args_list) == 0:
|
|
continue
|
|
ast_match = True
|
|
for arg in args_list:
|
|
if arg.decode().lstrip("'").rstrip("'") not in candidate_tree.text.decode():
|
|
ast_match = False
|
|
break
|
|
if ast_match:
|
|
return idx
|
|
return -1
|
|
|
|
|
|
def ast_eval_hf(api_database, qa_pairs, ast_database, question_id, response):
|
|
# Check correctness
|
|
correct = False
|
|
hallucination = False
|
|
output = response
|
|
# Index the "api_call" domain
|
|
output = output.split('api_call')
|
|
if len(output) == 1:
|
|
api_call = output[0]
|
|
else:
|
|
# Parse the output
|
|
output = output[1].split('api_provider')[0]
|
|
if ':' not in output:
|
|
start = 0
|
|
else:
|
|
start = output.index(':')
|
|
if ')' not in output:
|
|
end = -2
|
|
else:
|
|
end = output.rindex(')')
|
|
api_call = output[start + 2 : end + 1]
|
|
# Parse the api_call into AST tree
|
|
ast_tree = ast_parse(api_call)
|
|
# Search for a subtree
|
|
ast_subtree_list = get_all_sub_trees(ast_tree)
|
|
# Check which ast tree is matching
|
|
database_index = ast_check(ast_subtree_list, ast_database)
|
|
# We cannot index this ast in our database
|
|
if database_index == -1:
|
|
hallucination = True
|
|
# We index our reference api_call
|
|
ref_api_call = api_database[database_index]
|
|
# Check for functionality
|
|
if ref_api_call['domain'] == qa_pairs[question_id - 1]['domain']:
|
|
correct = True
|
|
return correct, hallucination
|