Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Defensive checks for ml labels in config #20

Merged
merged 18 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@ cmake_minimum_required(VERSION 3.21)
project(TTool VERSION "0.1.0" LANGUAGES CXX)

# zenodo direct doi link (to change if dataset us updated)
set(TTOOL_DSET_DOI "https://zenodo.org/record/10014284" CACHE INTERNAL "")
set(TTOOL_DSET_DOI "https://zenodo.org/doi/10.5281/zenodo.7956930" CACHE INTERNAL "")

# Not sur it is a good idea, it will fail on installed version (NR)
set(__TTOOL_CONFIG_PATH__ "${CMAKE_CURRENT_SOURCE_DIR}/assets/config.yml")
set(__TTOOL_ROOT_PATH__ "${CMAKE_CURRENT_SOURCE_DIR}")

include(cmake/dataset.cmake)
include(cmake/classifier.cmake)

file(GLOB glsl_SOURCES src/shader/*.glsl)
source_group("shader" FILES ${glsl_SOURCES})
Expand Down
11 changes: 0 additions & 11 deletions ai/torchscripts/label_map.txt

This file was deleted.

11 changes: 11 additions & 0 deletions ai/torchscripts/labels.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
auger_drill_bit_34_235
chain_swordsaw_blade_200
spade_drill_bit_25_150
brad_point_drill_bit_20_150
chain_saw_blade_f_250
self_feeding_bit_40_90
circular_saw_blade_makita_190
self_feeding_bit_50_90
twist_drill_bit_32_165
saber_saw_blade_makita_t_300
auger_drill_bit_20_235
90 changes: 50 additions & 40 deletions assets/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@ histOffset: 100
histRad: 40
searchRad: 25
classifierModelPath: "ai/torchscripts/efficientnet.pt"
classifierLabelsPath: "ai/torchscripts/labels.txt"

# The order of classifierLabels is determined by the classifier during training
# the classifier once trained output a .pt (weights) and a .txt file for the order
# see the .txt file in ai/torchscripts to see the order to recreate here below
classifierLabels:
- "auger_drill_bit_34_235"
- "chain_swordsaw_blade_200"
Expand All @@ -22,6 +20,7 @@ classifierLabels:
- "twist_drill_bit_32_165"
- "saber_saw_blade_makita_t_300"
- "auger_drill_bit_20_235"

classifierImageSize: 384
classifierImageChannels: 3
classifierMean:
Expand All @@ -34,49 +33,60 @@ classifierStd:
- 0.225
groundTruthPoses:
- [ -7.73640037e-01, 4.15423065e-01, -4.78431463e-01, -5.23223341e-01,
-8.44716668e-01, 1.12603128e-01, -3.57363552e-01, 3.37440640e-01,
8.70876431e-01, 0., 0., 2.20000029e-01 ]
-8.44716668e-01, 1.12603128e-01, -3.57363552e-01, 3.37440640e-01,
8.70876431e-01, 0., 0., 2.20000029e-01 ]
- [ 9.50834990e-01, 2.37247661e-01, -1.99040413e-01, 2.44926587e-02,
-6.98316872e-01, -7.15365469e-01, -3.08712870e-01, 6.75321281e-01,
-6.69801235e-01, -2.26506889e-02, 2.32661795e-02, 1.76990986e-01 ]
-6.98316872e-01, -7.15365469e-01, -3.08712870e-01, 6.75321281e-01,
-6.69801235e-01, -2.26506889e-02, 2.32661795e-02, 1.76990986e-01 ]
- [ 9.84866619e-01, -4.03894819e-02, 1.68541461e-01, -1.62771076e-01,
1.18437737e-01, 9.79524612e-01, -5.95245212e-02, -9.92136419e-01,
1.10068895e-01, 0., 0., 1.39999986e-01 ]
- [ 9.84866619e-01, -4.03894819e-02, 1.68541461e-01, -1.62771076e-01,
1.18437737e-01, 9.79524612e-01, -5.95245212e-02, -9.92136419e-01,
1.10068895e-01, 0., 0., 1.39999986e-01 ]
- [ -5.42542815e-01, 8.39947045e-01, -1.11253709e-02, 5.06054997e-01,
3.16238910e-01, -8.02426398e-01, -6.70481861e-01, -4.40981954e-01,
-5.96643448e-01, -1.26680557e-03, 4.37926613e-02, 2.52720535e-01 ]
1.18437737e-01, 9.79524612e-01, -5.95245212e-02, -9.92136419e-01,
1.10068895e-01, 0., 0., 1.39999986e-01 ]
- [ 8.19440663e-01, 5.45331001e-01, -1.76395491e-01, 1.42224208e-01,
-4.91604596e-01, -8.59123707e-01, -5.55223823e-01, 6.78913355e-01,
-4.80406970e-01, 1.67723373e-02, 2.61279512e-02, 1.92897707e-01 ]
-4.91604596e-01, -8.59123707e-01, -5.55223823e-01, 6.78913355e-01,
-4.80406970e-01, 1.67723373e-02, 2.61279512e-02, 1.92897707e-01 ]
- [ 9.84866619e-01, -4.03894819e-02, 1.68541461e-01, -1.62771076e-01,
1.18437737e-01, 9.79524612e-01, -5.95245212e-02, -9.92136419e-01,
1.10068895e-01, 0., 0., 1.39999986e-01 ]
- [ 8.19440663e-01, 5.45331001e-01, -1.76395491e-01, 1.42224208e-01,
-4.91604596e-01, -8.59123707e-01, -5.55223823e-01, 6.78913355e-01,
-4.80406970e-01, 1.67723373e-02, 2.61279512e-02, 1.92897707e-01 ]

- [ 9.53821719e-01, 2.47012243e-01, 1.70877323e-01, 2.97439009e-01,
-6.97671890e-01, -6.51752949e-01, -4.17775214e-02, 6.72493160e-01,
-7.38922477e-01, -2.76020560e-02, 3.86725217e-02, 1.33150429e-01 ]
-6.97671890e-01, -6.51752949e-01, -4.17775214e-02, 6.72493160e-01,
-7.38922477e-01, -2.76020560e-02, 3.86725217e-02, 1.33150429e-01 ]
- [ -8.07440519e-01, 4.48605478e-01, -3.83105516e-01, -5.85106611e-01,
-6.91857159e-01, 4.23044086e-01, -7.52737448e-02, 5.65754294e-01,
8.21124077e-01, -5.82594611e-03, 2.99568456e-02, 1.46538332e-01 ]
-6.91857159e-01, 4.23044086e-01, -7.52737448e-02, 5.65754294e-01,
8.21124077e-01, -5.82594611e-03, 2.99568456e-02, 1.46538332e-01 ]
- [ -8.28246951e-01, 2.30366185e-01, 5.10807216e-01, 2.39123389e-01,
-6.79095149e-01, 6.94000125e-01, 5.06760836e-01, 6.96950197e-01,
5.07382751e-01, -2.56683957e-02, 2.72371043e-02, 1.72090665e-01 ]
-6.79095149e-01, 6.94000125e-01, 5.06760836e-01, 6.96950197e-01,
5.07382751e-01, -2.56683957e-02, 2.72371043e-02, 1.72090665e-01 ]
- [ -9.45800602e-01, 2.85116285e-01, 1.55401364e-01, -8.82991254e-02,
-6.86366022e-01, 7.21853733e-01, 3.12482119e-01, 6.69024289e-01,
6.74335539e-01, -2.39054989e-02, 4.05580476e-02, 1.38083279e-01 ]
-6.86366022e-01, 7.21853733e-01, 3.12482119e-01, 6.69024289e-01,
6.74335539e-01, -2.39054989e-02, 4.05580476e-02, 1.38083279e-01 ]
modelFiles:
- "assets/toolheads/saber_saw_blade_makita_t_300/model.obj"
- "assets/toolheads/twist_drill_bit_32_165/model.obj"
- "assets/toolheads/circular_saw_blade_makita_190/model.obj"
- "assets/toolheads/chain_saw_blade_f_250/model.obj"
- "assets/toolheads/auger_drill_bit_20_235/model.obj"
- "assets/toolheads/brad_point_drill_bit_20_150/model.obj"
- "assets/toolheads/spade_drill_bit_25_150/model.obj"
- "assets/toolheads/self_feeding_bit_40_90/model.obj"
- "assets/toolheads/self_feeding_bit_50_90/model.obj"
- "/assets/toolheads/saber_saw_blade_makita_t_300/model.obj"
- "/assets/toolheads/twist_drill_bit_32_165/model.obj"
- "/assets/toolheads/circular_saw_blade_makita_190/model.obj"
- "/assets/toolheads/chain_saw_blade_f_250/model.obj"
- "/assets/toolheads/auger_drill_bit_20_235/model.obj"
- "/assets/toolheads/chain_swordsaw_blade_200/model.obj"
- "/assets/toolheads/auger_drill_bit_34_235/model.obj"
- "/assets/toolheads/brad_point_drill_bit_20_150/model.obj"
- "/assets/toolheads/spade_drill_bit_25_150/model.obj"
- "/assets/toolheads/self_feeding_bit_40_90/model.obj"
- "/assets/toolheads/self_feeding_bit_50_90/model.obj"
acitFiles:
- "assets/toolheads/saber_saw_blade_makita_t_300/metadata.acit"
- "assets/toolheads/twist_drill_bit_32_165/metadata.acit"
- "assets/toolheads/circular_saw_blade_makita_190/metadata.acit"
- "assets/toolheads/chain_saw_blade_f_250/metadata.acit"
- "assets/toolheads/auger_drill_bit_20_235/metadata.acit"
- "assets/toolheads/brad_point_drill_bit_20_150/metadata.acit"
- "assets/toolheads/spade_drill_bit_25_150/metadata.acit"
- "assets/toolheads/self_feeding_bit_40_90/metadata.acit"
- "assets/toolheads/self_feeding_bit_50_90/metadata.acit"
- "/assets/toolheads/saber_saw_blade_makita_t_300/metadata.acit"
- "/assets/toolheads/twist_drill_bit_32_165/metadata.acit"
- "/assets/toolheads/circular_saw_blade_makita_190/metadata.acit"
- "/assets/toolheads/chain_saw_blade_f_250/metadata.acit"
- "/assets/toolheads/auger_drill_bit_20_235/metadata.acit"
- "/assets/toolheads/chain_swordsaw_blade_200/metadata.acit"
- "/assets/toolheads/auger_drill_bit_34_235/metadata.acit"
- "/assets/toolheads/brad_point_drill_bit_20_150/metadata.acit"
- "/assets/toolheads/spade_drill_bit_25_150/metadata.acit"
- "/assets/toolheads/self_feeding_bit_40_90/metadata.acit"
- "/assets/toolheads/self_feeding_bit_50_90/metadata.acit"
8 changes: 8 additions & 0 deletions cmake/classifier.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@

if(UNIX AND NOT APPLE)
set(LOADER_CMD "${PROJECT_SOURCE_DIR}/util/load_labels_2_config.py")
execute_process(
COMMAND chmod +x ${LOADER_CMD}
COMMAND ${LOADER_CMD} -s ${CMAKE_CURRENT_SOURCE_DIR} -c ${__TTOOL_CONFIG_PATH__}
)
endif()
80 changes: 79 additions & 1 deletion include/config.hh
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
#include <variant>
#include <functional>
#include <unordered_map>
#include <filesystem>
#include <algorithm>
#include <unordered_set>
#include <fstream>

namespace ttool
{
Expand Down Expand Up @@ -203,6 +207,77 @@ namespace ttool
LoadConfigFile();
}

/**
* @brief Check if the acit names match the folder names
*
*/
void CheckAcitFiles(const std::string& TToolRootPath)
{

std::filesystem::path rootPath = std::filesystem::current_path() / TToolRootPath;

std::string line;
std::string toolheadNameTagStart = "<toolhead name=\"";
std::string toolheadNameTagEnd = "\"";

for (const auto& acitFile : m_ConfigData.AcitFiles) {
std::string acitFileR = acitFile.substr(1);
std::filesystem::path acitFilePath = rootPath / acitFileR;
std::string toolheadName = "";

std::ifstream fs(acitFilePath);
if (!fs.is_open()) {
throw std::runtime_error("Could not open file: " + acitFilePath.string());
}

while (std::getline(fs, line)) {
size_t start = line.find(toolheadNameTagStart);
if (start != std::string::npos) {
start += toolheadNameTagStart.length();
size_t end = line.find(toolheadNameTagEnd, start);
if (end != std::string::npos) {
toolheadName = line.substr(start, end - start);
break;
}
}
}
fs.close();

if (acitFile.find(toolheadName) == std::string::npos) {
throw std::runtime_error("Toolhead name mismatch error: Toolhead name \"" + toolheadName +
"\" does not match the folder name \"" + acitFile + "\"");
}
}
}

/**
* @brief Check if the labels in the config file match the file paths of model files and acit files
*
*/
void CheckClassifierLabelsConfig()
{
std::unordered_set<std::string> filePaths;

for (const auto& modelFile : m_ConfigData.ModelFiles) {
filePaths.insert( modelFile);
}
for (const auto& acitFile : m_ConfigData.AcitFiles) {
filePaths.insert(acitFile);
}

for (const auto& label : m_ConfigData.ClassifierLabels) {
bool labelMatches = std::any_of(filePaths.begin(), filePaths.end(),
[&label](const std::string& filePath) {
return filePath.find(label) != std::string::npos;
});

if (!labelMatches) {
throw std::runtime_error("Label mismatch error: Label \"" + label + "\" does not match any file paths");
}
}

}

/**
* @brief Read the config file and set the values to the ConfigData object
*
Expand Down Expand Up @@ -248,7 +323,7 @@ namespace ttool
return fs.release();
}

/**
/**
* @brief Print the config file to the console
*
*/
Expand Down Expand Up @@ -341,17 +416,20 @@ namespace ttool
*/
ConfigData GetConfigData()
{
std::vector<std::string> fileNames;
// Create a copy of the ConfigData object
ConfigData configData = this->m_ConfigData;
// Prefix the model files with the m_TToolRootPath
for (auto& modelFile : configData.ModelFiles)
{
modelFile = std::string(m_TToolRootPath) + "/" + modelFile;
fileNames.push_back(modelFile);
}
// Prefix the acit files with the m_TToolRootPath
for (auto& acitFile : configData.AcitFiles)
{
acitFile = std::string(m_TToolRootPath) + "/" + acitFile;
fileNames.push_back(acitFile);
}
// Prefix the classifier model path with the m_TToolRootPath
configData.ClassifierModelPath = std::string(m_TToolRootPath) + "/" + configData.ClassifierModelPath;
Expand Down
2 changes: 2 additions & 0 deletions include/ttool.hh
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,8 @@ namespace ttool
m_ConfigFile = configFile;
m_ConfigPtr = std::make_shared<ttool::Config>(configFile);
m_ConfigPtr->SetTToolRootPath(ttoolRootPath);
m_ConfigPtr->CheckAcitFiles(ttoolRootPath);
m_ConfigPtr->CheckClassifierLabelsConfig();
}

/**
Expand Down
88 changes: 88 additions & 0 deletions util/load_labels_2_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#!/usr/bin/env python3

import os
import sys
import argparse


def _log_process(msg):
print("[PROCESS:util/load_dataset_2_config.py] {}".format(msg))
def _log_info(msg):
print("\033[95m[INFO:util/load_dataset_2_config.py] {}\033[00m".format(msg))
def _log_error(msg):
print("\033[91m[ERROR:util/load_dataset_2_config.py] {}\033[00m".format(msg))
def _log_warning(msg):
print("\033[93m[WARNING:util/load_dataset_2_config.py] {}\033[00m".format(msg))
def _log_success(msg):
print("\033[92m[SUCCESS:util/load_dataset_2_config.py] {}\033[00m".format(msg))


def main(source_path: str, config_path: str) -> None:
_log_process("Loading the labels file...")

try:
classifier_labels_path_key: str = 'classifierLabelsPath:'
classifier_labels: str = 'classifierLabels:'
with open(config_path, "r") as f:
config_lines = f.readlines()

labels_path = None
for line in config_lines:
if classifier_labels_path_key in line:
_, labels_path = line.split(':', 1)
labels_path = labels_path.strip().strip('\"')
break

if not labels_path:
raise ValueError(f"{classifier_labels_path_key} not found in the configuration file")

with open(os.path.join(source_path, labels_path), "r") as f:
labels = [line.strip() for line in f if line.strip()]

start_idx = -1
end_idx = -1
for i, line in enumerate(config_lines):
if line.strip() == classifier_labels:
start_idx = i
end_idx = start_idx + 1
while end_idx < len(config_lines) and config_lines[end_idx].strip().startswith('-'):
end_idx += 1
break

new_labels_section = [f'{classifier_labels}\n'] + \
[f" - \"{label}\"\n" for label in labels]

if start_idx != -1:
config_lines = config_lines[:start_idx] + new_labels_section + config_lines[end_idx:]
else:
config_lines.extend(['\n'] + new_labels_section)

with open(config_path, "w") as f:
f.writelines(config_lines)

_log_success("The labels were successfully loaded into the assets/config.yml file.")

except FileNotFoundError as e:(
_log_error(f"Error: File not found - {e}"))
except ValueError as e:(
_log_error(f"Error: {e}"))
except Exception as e:(
_log_error(f"An unexpected error occurred: {e}"))



if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Load the labels into the TTool config.yml file.")
parser.add_argument("-s", "--source", help="Path to the project source directory.")
parser.add_argument("-c", "--config", help="Path to the TTool config.yml file.")

args = parser.parse_args()

if not os.path.isdir(args.source):
_log_error("The path to the project source directory is not valid.")
sys.exit()
if not os.path.isfile(args.config):
_log_error("The path to the config file is not valid.")
sys.exit()

main(source_path=args.source, config_path=args.config)