Skip to content

Commit

Permalink
adjust to review comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonDold committed Jul 2, 2024
1 parent 7fd9d24 commit dd1fda2
Showing 1 changed file with 56 additions and 67 deletions.
123 changes: 56 additions & 67 deletions misc/tests/test-parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,30 +7,41 @@
"""

import argparse
import os
import os.path
from pathlib import Path
import re
import subprocess
import sys

DIR = os.path.dirname(os.path.abspath(__file__))
REPO = os.path.dirname(os.path.dirname(DIR))
SRC_DIR = os.path.join(REPO, "src")

SHORT_HANDS = {
"ipdb": "cpdbs(hillclimbing())",
"astar": "eager(tiebreaking([sum([g(), h]), h],"
" unsafe_pruning=false), reopen_closed=true,"
" f_eval=sum([g(), h]))",
"lazy_greedy": "lazy(alt([single(h1), single(h1,"
" pref_only=true), single(h2), single(h2,"
" pref_only=true)], boost=100), preferred=h2)",
"lazy_wastar": "lazy(single(sum([g(), weight(eval1, 2)])),"
" reopen_closed=true)",
"eager_greedy": "eager(single(eval1))",
"eager_wastar": """ See corresponding notes for"""
""" "(Weighted) A* search (lazy)" """,
}
DIR = Path(__file__).resolve().parent
REPO = DIR.parents[1]
SRC_DIR = REPO / "src"

SHORT_HANDS = [
"ipdb", # cpdbs(hillclimbing())
"astar", # eager(tiebreaking([sum([g(), h]), h],
# unsafe_pruning=false), reopen_closed=true,
# f_eval=sum([g(), h]))
"lazy_greedy", # lazy(alt([single(h1), single(h1,
# pref_only=true), single(h2), single(h2,
# pref_only=true)], boost=100), preferred=h2)
"lazy_wastar", # lazy(single(sum([g(), weight(eval1, 2)])),
# reopen_closed=true)
"eager_greedy", # eager(single(eval1))
"eager_wastar", # See corresponding notes for
# "(Weighted) A* search (lazy)"
# eager_wastar(evals=[eval1, eval2],prefered=pref1,
# reopen_closed=rc1, boost=boo1, w=w1,
# pruning=pru1, cost_type=ct1, bound = bou1,
# max_time=mt1, verbosity=v1)
# Is equivalent to:
# eager(open = alt([single(sum([g(), weight(eval1, w1)])),
# single(sum([g(), weight(eval2, w1)]))],
# boost=boo1),
# reopen_closed=rc1, f_eval = <none>,
# preferred = pref1, pruning = pru1,
# cost_type=ct1, bound=bou1, max_time=mt1,
# verbosity=v1)
]

TEMPORARY_EXCEPTIONS = [
"iterated",
Expand All @@ -46,19 +57,7 @@
]

CREATE_COMPONENT_REGEX = r"(^|\s|\W)create_component"
C_VAR_PATTERN = r'[^a-zA-Z0-9_]' # overapproximation

def get_src_files(path, extensions, ignore_dirs=None):
ignore_dirs = ignore_dirs or []
src_files = []
for root, dirs, files in os.walk(path):
for ignore_dir in ignore_dirs:
if ignore_dir in dirs:
dirs.remove(ignore_dir)
src_files.extend([
os.path.join(root, file)
for file in files if file.endswith(extensions)])
return src_files
NON_C_VAR_PATTERN = r'[^a-zA-Z0-9_]' # overapproximation

def extract_cpp_class(input_string):
pattern = r'<(.*?)>'
Expand Down Expand Up @@ -113,7 +112,6 @@ def extract_feature_name_and_cpp_class(cc_file, cc_files, cwd, num):
class_pattern = r'TypedFeature<(.*?)> {'
feature_names = []
class_names = []
other_namespaces = []
feature_error_msgs = []
class_error_msgs = []
for line in source_without_comments.splitlines():
Expand All @@ -125,17 +123,14 @@ def extract_feature_name_and_cpp_class(cc_file, cc_files, cwd, num):
if re.search(class_pattern, line):
feature_class = re.search(class_pattern, line).group(1)
class_name = feature_class.split()[-1].split("::")[-1]
other_namespace = (len(feature_class.split()[-1].split("::"))
== 2)
class_error_msg = "class_name: " + class_name + "\n"
class_names.append(class_name)
other_namespaces.append(other_namespace)
class_error_msgs.append(class_error_msg)
return (feature_names[num], class_names[num], other_namespaces[num],
return (feature_names[num], class_names[num],
feature_error_msgs[num] + class_error_msgs[num])

def get_cpp_class_parameters(
class_name, other_namespace, cc_file, cc_files, cwd):
class_name, cc_file, cc_files, cwd):
found_in_file, parameters = get_constructor_parameters(
cc_file, class_name)
if not found_in_file:
Expand All @@ -149,7 +144,7 @@ def get_cpp_class_parameters(
parameters = parameters.replace("\n", "") + ","
parameters = parameters.split()
parameters = [word for word in parameters if "," in word]
parameters = [re.sub(C_VAR_PATTERN, '', word)
parameters = [re.sub(NON_C_VAR_PATTERN, '', word)
for word in parameters]
return parameters
else:
Expand All @@ -166,28 +161,18 @@ def get_create_component_lines(cc_file):
return lines

def compare_component_parameters(cc_file, cc_files, cwd):
found_error = False
error_msg = ""
create_component_lines = get_create_component_lines(cc_file)
if not create_component_lines == []:
if create_component_lines:
for i, create_component_line in enumerate(
create_component_lines):
(feature_name, cpp_class, other_namespace,
extracted_error_msg) = (
(feature_name, cpp_class, extracted_error_msg) = (
extract_feature_name_and_cpp_class(
cc_file, cc_files, cwd, i))
error_msg += "\n\n=====================================\n"
error_msg += "= = = " + cpp_class + " = = =\n"
error_msg += extracted_error_msg + "\n"
feature_parameters = extract_feature_parameter_list(
feature_name)
error_msg += ("== FEATURE PARAMETERS '"
+ feature_name + "'==\n")
error_msg += str(feature_parameters) + "\n"
feature_name)
cpp_class_parameters = get_cpp_class_parameters(
cpp_class, other_namespace, cc_file, cc_files, cwd)
error_msg += "== CLASS PARAMETERS '" + cpp_class + "'==\n"
error_msg += str(cpp_class_parameters) + "\n"
cpp_class, cc_file, cc_files, cwd)
if feature_name in SHORT_HANDS:
print(f"feature_name '{feature_name}' is ignored"
" because it is marked as shorthand")
Expand All @@ -198,7 +183,15 @@ def compare_component_parameters(cc_file, cc_files, cwd):
print(f"feature_name '{feature_name}' is ignored"
" because it is marked as TEMPORARY_EXCEPTION")
elif feature_parameters != cpp_class_parameters:
found_error = True
error_msg += ( "\n\n=====================================\n"
+ "= = = " + cpp_class + " = = =\n"
+ extracted_error_msg + "\n"
+ "== FEATURE PARAMETERS '"
+ feature_name + "' ==\n"
+ str(feature_parameters) + "\n"
+ "== CLASS PARAMETERS '"
+ cpp_class + "' ==\n"
+ str(cpp_class_parameters) + "\n")
if not len(feature_parameters) == len(cpp_class_parameters):
error_msg += "Wrong sizes\n"
for i in range(min(len(feature_parameters),
Expand All @@ -207,22 +200,17 @@ def compare_component_parameters(cc_file, cc_files, cwd):
error_msg += (feature_parameters[i] +
" =/= " + cpp_class_parameters[i] + "\n")
error_msg += cc_file + "\n"
return found_error, error_msg
return error_msg

def error_check(cc_files, cwd):
errors = []
for cc_file in cc_files:
found_error, error = compare_component_parameters(
error_msg = compare_component_parameters(
cc_file, cc_files, cwd)
if found_error:
errors.append(error)
if error_msg:
errors.append(error_msg)
if errors:
print("######################################################")
print("######################################################")
print("######################################################")
print("######################################################")
print("######################################################")
print(".: ERRORS :.")
print("############### ERRORS ##########################")
for error in errors:
print(error)
sys.exit(1)
Expand All @@ -232,8 +220,9 @@ def main():
Currently, we only check that the parameters in the Constructor in
the .cc file matches the parameters for the CLI.
"""
search_dir = os.path.join(SRC_DIR, "search")
cc_files = get_src_files(search_dir, (".cc",))
search_dir = SRC_DIR / "search"
cc_files = [str(file) for file in search_dir.rglob('*.cc')
if file.is_file()]
assert len(cc_files) > 0
print("Checking Component Parameters of"
" {} *.cc files".format(len(cc_files)))
Expand Down

0 comments on commit dd1fda2

Please sign in to comment.