Skip to content

Commit

Permalink
fix for error tabular output table
Browse files Browse the repository at this point in the history
  • Loading branch information
Abellegese committed Jan 16, 2025
1 parent 5fa3bd8 commit e283d62
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 60 deletions.
14 changes: 6 additions & 8 deletions ersilia/cli/commands/run.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import types

import click
Expand Down Expand Up @@ -69,16 +68,15 @@ def run(input, output, batch_size, as_table):
batch_size=batch_size,
track_run=track_runs,
)
iter_values = []
if isinstance(result, types.GeneratorType):
for result in mdl.run(input=input, output=output, batch_size=batch_size):
if result is not None:
formatted = json.dumps(result, indent=4)
if as_table:
print_result_table(formatted)
else:
echo(formatted)
else:
echo("Something went wrong", fg="red")
iter_values.append(result)
if as_table:
print_result_table(iter_values)
else:
echo(iter_values)
else:
if as_table:
print_result_table(result)
Expand Down
17 changes: 17 additions & 0 deletions ersilia/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,3 +198,20 @@ def bashrc_cli_snippet(overwrite=True):
f.write(text)
with open(fn, "a+") as f:
f.write(snippet)


# Types accepted as a scalar cell value. Note that bool passes as well,
# because bool is a subclass of int.
_SCALAR_TYPES = (str, int, float)


def _all_scalars(items):
    """Return True when every element of *items* is a str, int or float."""
    return all(isinstance(item, _SCALAR_TYPES) for item in items)


def _is_single(x):
    """Match a list holding exactly one element (of any type)."""
    return isinstance(x, list) and len(x) == 1


def _is_list(x):
    """Match a list of two or more scalars."""
    return isinstance(x, list) and len(x) > 1 and _all_scalars(x)


def _is_flexible_list(x):
    """Match a list of scalars of any length, the empty list included."""
    return isinstance(x, list) and _all_scalars(x)


def _is_matrix(x):
    """Match a list whose elements are all lists of scalars."""
    return isinstance(x, list) and all(
        isinstance(row, list) and _all_scalars(row) for row in x
    )


def _is_serializable_object(x):
    """Match a dictionary."""
    return isinstance(x, dict)


# Maps a data-structure label to a predicate recognising that shape.
# Insertion order matters to callers that take the first matching entry.
OUTPUT_DATASTRUCTURE = {
    "Single": _is_single,
    "List": _is_list,
    "Flexible List": _is_flexible_list,
    "Matrix": _is_matrix,
    "Serializable Object": _is_serializable_object,
}
164 changes: 112 additions & 52 deletions ersilia/utils/terminal.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
import csv
import io
import json
import os
import shutil
import subprocess

from .logging import logger

try:
from inputimeout import TimeoutOccurred, inputimeout
except:
inputimeout = None
TimeoutOccurred = None

from ..default import VERBOSE_FILE
from ..default import OUTPUT_DATASTRUCTURE, VERBOSE_FILE
from ..utils.logging import make_temp_dir
from ..utils.session import get_session_dir
from .hdf5 import Hdf5DataLoader


def is_quiet():
Expand Down Expand Up @@ -159,75 +157,137 @@ def yes_no_input(prompt, default_answer, timeout=5):
return True


def _flatten_data(json_data):
    """
    Flatten a list of (possibly nested) records into one-level row dictionaries.

    Parameters
    ----------
    json_data : list
        Iterable of dictionaries. A value that is itself a dictionary is
        expanded one level: each of its entries becomes a column of the row.
        Any other value becomes a column named after its own key.

    Returns
    -------
    list
        One dictionary per record, mapping column name to the string produced
        by ``_handle_supported_structures``. Fields named ``"text"`` are
        dropped. When two fields share a name, the later one wins.
    """
    rows = []
    for record in json_data:
        flat = {}
        for name, content in record.items():
            # A nested mapping contributes its own fields; anything else is a
            # single field stored under its own key.
            if isinstance(content, dict):
                fields = content.items()
            else:
                fields = [(name, content)]
            for field, raw in fields:
                if field != "text":
                    flat[field] = _handle_supported_structures(raw)
        rows.append(flat)
    return rows


def _handle_supported_structures(value):
    """
    Render one output value as a single table-cell string.

    The value is matched against the predicates in ``OUTPUT_DATASTRUCTURE``
    in their declared order, and the first recognised shape decides the format.

    Parameters
    ----------
    value : object
        The value to render.

    Returns
    -------
    str
        * one-element list: the rendering of that element;
        * list of scalars: items joined with ``", "``;
        * list of lists of scalars: rows joined with ``"; "``, items with ``", "``;
        * dictionary: its JSON encoding;
        * anything else: ``str(value)``.
    """
    for dtype, checker in OUTPUT_DATASTRUCTURE.items():
        if not checker(value):
            continue
        if dtype == "Single":
            inner = value[0]
            # "Single" matches any one-element list, so the element may itself
            # be a row or a mapping. Render it with the same rules so that a
            # one-row matrix looks like a multi-row one instead of falling
            # back to the Python repr.
            if isinstance(inner, (list, dict)):
                return _handle_supported_structures(inner)
            return str(inner)
        if dtype in ("List", "Flexible List"):
            return ", ".join(map(str, value))
        if dtype == "Matrix":
            return "; ".join(", ".join(map(str, row)) for row in value)
        if dtype == "Serializable Object":
            return json.dumps(value)
    return str(value)


def _read_hdf5_with_loader(file_path):
    """
    Load an HDF5 file through ``Hdf5DataLoader`` and return it as table rows.

    Parameters
    ----------
    file_path : str
        Path handed to ``Hdf5DataLoader.load``.

    Returns
    -------
    list
        One dictionary per entry of ``loader.keys`` with the columns ``Key``,
        ``Input``, ``Value`` and ``Feature``. A column is ``None`` when the
        corresponding loader sequence is shorter than the key's position.
    """
    loader = Hdf5DataLoader()
    loader.load(file_path)

    def _pick(column, position):
        # Tolerate loader sequences that are shorter than the key list.
        return column[position] if position < len(column) else None

    return [
        {
            "Key": key,
            "Input": _pick(loader.inputs, idx),
            "Value": _pick(loader.values, idx),
            "Feature": _pick(loader.features, idx),
        }
        for idx, key in enumerate(loader.keys)
    ]


def _read_csv(file_path):
    """
    Read a CSV file into a list of row dictionaries.

    Parameters
    ----------
    file_path : str
        Path to the CSV file. Its first row is used as the header.

    Returns
    -------
    list
        One dictionary per data row, mapping column name to cell text.
    """
    # newline="" hands line endings to the csv module untouched, as its
    # documentation requires; otherwise newlines embedded in quoted fields
    # are rewritten by universal-newline translation.
    with open(file_path, mode="r", newline="") as file:
        return [dict(row) for row in csv.DictReader(file)]


def print_result_table(data):
    """
    Print a result table with solid borders from JSON, CSV, or HDF5-like data.

    Parameters
    ----------
    data : str or list
        A list of dictionaries (one per row), a JSON string encoding such a
        list, or the path to a ``.json``, ``.csv``, ``.h5`` or ``.hdf5`` file.
        Rows whose first record has a truthy ``"input"`` or ``"output"`` entry
        are flattened with ``_flatten_data`` before printing.

    Raises
    ------
    ValueError
        If a string argument is valid JSON but not a list of dictionaries, is
        neither JSON nor an existing file, or names a file with an unsupported
        extension.
    """
    COLOR, BORDER_CHAR, VERTICAL_BORDER = "\033[0m", "━", "┃"

    if isinstance(data, str):
        try:
            parsed_data = json.loads(data)
        except json.JSONDecodeError:
            # Not JSON text, so the string must be a path to a supported file.
            if not os.path.isfile(data):
                raise ValueError(
                    f"Provided string is neither valid JSON nor a file path. {data}"
                )
            if data.endswith(".json"):
                with open(data, mode="r") as file:
                    data = json.load(file)
            elif data.endswith(".csv"):
                data = _read_csv(data)
            elif data.endswith((".h5", ".hdf5")):
                data = _read_hdf5_with_loader(data)
            else:
                raise ValueError(f"Unsupported file type: {data}")
        else:
            if isinstance(parsed_data, list) and all(
                isinstance(item, dict) for item in parsed_data
            ):
                data = parsed_data
            else:
                raise ValueError(
                    "The JSON string must represent a list of dictionaries."
                )

    if not (isinstance(data, list) and len(data) > 0 and isinstance(data[0], dict)):
        print(
            f"Invalid input data format. Please provide valid JSON, CSV, or HDF5 data.{data}"
        )
        return

    if data[0].get("input") or data[0].get("output"):
        data = _flatten_data(data)

    # The first row defines the columns; later rows may lack some of them.
    headers = list(data[0].keys())
    column_widths = {
        header: max(len(header), max(len(str(row.get(header, ""))) for row in data))
        + 2
        for header in headers
    }

    def format_row(cells):
        # Each cell is padded to its column width plus one space on each side.
        padded = (
            f" {cell.ljust(column_widths[header])} "
            for header, cell in zip(headers, cells)
        )
        return COLOR + VERTICAL_BORDER + VERTICAL_BORDER.join(padded) + VERTICAL_BORDER + COLOR

    # Row width: the column widths, two padding spaces per column, and
    # len(headers) + 1 vertical bars. The "+ 1" is the closing bar; without
    # it the horizontal borders are one character short of the rows.
    total_width = sum(column_widths.values()) + 3 * len(headers) + 1
    border_line = BORDER_CHAR * total_width

    print(border_line)
    print(format_row(headers))
    print(border_line)
    for row in data:
        print(format_row([str(row.get(header, "")) for header in headers]))
    print(border_line)


def read_csv_from_string(csv_string):
    """
    Read CSV data from a string.

    Parameters
    ----------
    csv_string : str
        The CSV data as a string. The first row is used as the header.

    Returns
    -------
    list
        A list of dictionaries representing the CSV data.
    """
    # newline="" makes the buffer split on \n, \r and \r\n without translating
    # them. With the default mode only \n ends a line, so CR-terminated input
    # reaches the csv parser as one line and is rejected.
    f = io.StringIO(csv_string, newline="")
    reader = csv.DictReader(f)
    return [dict(row) for row in reader]

0 comments on commit e283d62

Please sign in to comment.