Skip to content

Commit

Permalink
[Feature](mluAdamW): update adam_w_lite
Browse files Browse the repository at this point in the history
  • Loading branch information
chqy99 committed Jan 23, 2025
1 parent b8b44a7 commit 6a1ee01
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions scripts/bangc_kernels_path_config/bangc_kernels_path_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ def get_relative_paths(absolute_path):
for folder, dirs, files in os.walk(absolute_path):
# get relative_folder path
relative_folder = os.path.relpath(folder, start=absolute_path)
if "kernels" in relative_folder:
relative_paths.append(relative_folder)
print(relative_folder)

# concat path
for file in files:
Expand Down Expand Up @@ -93,13 +96,13 @@ def extract_headers(json_file_path, common_flag=True, header_flag=True, sources_
op_name.extend(operator['name'])
else:
op_name.append(operator['name'])

if 'header' in operator and header_flag:
if isinstance(operator['header'], list):
header_files.extend(operator['header'])
else:
header_files.append(operator['header'])

if 'sources' in operator and sources_flag:
if isinstance(operator['sources'], list):
header_files.extend(operator['sources'])
Expand All @@ -111,7 +114,7 @@ def extract_headers(json_file_path, common_flag=True, header_flag=True, sources_

'''
params:
1. header_files_all: all .h/.mlu/.mluh/.cpp paths under "mlu-ops/"
1. header_files_all: all .h/.mlu/.mluh/.cpp paths under "mlu-ops/"
2. header_files: all .h/.mlu/.mluh/.cpp paths from JSON
3. header_files_unique: Deduplicated header_files
'''
Expand All @@ -130,7 +133,7 @@ def extract_headers(json_file_path, common_flag=True, header_flag=True, sources_
print(path)
header_files.extend(files_path) if isinstance(files_path, list) else header_files.append(files_path)
op_name_list.extend(op_name) if isinstance(op_name, list) else header_files.append(op_name)

# get header_files_unique
header_files_unique = list(set(header_files))
assert len(header_files_unique) is len(header_files), "There are duplicate paths in JSON files {}. ".format(json_paths)
Expand All @@ -140,4 +143,4 @@ def extract_headers(json_file_path, common_flag=True, header_flag=True, sources_
print("Bangc kernels path check success.")

except Exception as e:
print(f"[ERROR] Check failed. : {e}")
print(f"[ERROR] Check failed. : {e}")

0 comments on commit 6a1ee01

Please sign in to comment.