Skip to content

Commit

Permalink
add docstring checks to pyproject.toml
Browse files Browse the repository at this point in the history
  • Loading branch information
jonchang committed Jan 14, 2025
1 parent 0d8b79f commit 0ef1250
Show file tree
Hide file tree
Showing 15 changed files with 345 additions and 194 deletions.
7 changes: 2 additions & 5 deletions OCR/benchmark_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,9 @@ def main():


def run_segmentation_and_ocr(args):
"""
Runs segmentation and OCR processing.
"""Runs segmentation and OCR processing.
Returns OCR results with processing time.
"""

Check failure on line 51 in OCR/benchmark_main.py

View workflow job for this annotation

GitHub Actions / python

Ruff (D205)

OCR/benchmark_main.py:49:5: D205 1 blank line required between summary line and description

model = None

if args.model == "tesseract":
Expand All @@ -70,8 +68,7 @@ def run_segmentation_and_ocr(args):


def run_metrics_analysis(args, ocr_results):
"""
Runs metrics analysis based on OCR output and ground truth.
"""Runs metrics analysis based on OCR output and ground truth.
Uses OCR results to capture time values if available.
"""

Check failure on line 73 in OCR/benchmark_main.py

View workflow job for this annotation

GitHub Actions / python

Ruff (D205)

OCR/benchmark_main.py:71:5: D205 1 blank line required between summary line and description
metrics_analysis = BatchMetricsAnalysis(args.output_folder, args.ground_truth_folder, args.csv_output_folder)
Expand Down
2 changes: 1 addition & 1 deletion OCR/ocr/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,5 +149,5 @@ async def image_to_text(


def start():
"""Launched with `poetry run start` at root level"""
"""Launched with `poetry run start` at root level."""
uvicorn.run(app, host="0.0.0.0", port=8000, reload=False)
64 changes: 56 additions & 8 deletions OCR/ocr/services/alignment/backends/four_point_transform.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
"""
Uses quadrilateral edge detection and executes a four-point perspective transform on a source image.
"""
"""Uses quadrilateral edge detection and executes a four-point perspective transform on a source image."""

from pathlib import Path
import functools
Expand All @@ -10,19 +8,50 @@


class FourPointTransform:
"""A class to perform a four-point perspective transformation on an image.
This involves detecting the largest quadrilateral in the image and transforming it
to a standard rectangular form using a perspective warp.
Attributes:
image (np.ndarray): The input image as a NumPy array.
"""
def __init__(self, image: Path | np.ndarray):
"""Initializes the FourPointTransform object with an image.
The image can either be provided as a file path (Path) or a NumPy array.
Args:
image (Path | np.ndarray): The input image, either as a path to a file or as a NumPy array.
"""
if isinstance(image, np.ndarray):
self.image = image
else:
self.image = cv.imread(str(image))

@classmethod
def align(self, source_image, template_image):
def align(self, source_image: np.ndarray, template_image: np.ndarray) -> np.ndarray:
"""Aligns a source image to a template image using the four-point transform.
Args:
source_image (np.ndarray): The source image to be aligned.
template_image (np.ndarray): The template image. Not used by this backend; only `source_image` is dewarped.
Returns:
np.ndarray: The transformed image.
"""
return FourPointTransform(source_image).dewarp()

@staticmethod
def _order_points(quadrilateral: np.ndarray) -> np.ndarray:
"Reorder points from a 4x2 input array representing the vertices of a quadrilateral, such that the coordinates of each vertex are arranged in order from top left, top right, bottom right, and bottom left."
"""Reorders the points of a quadrilateral from an unordered 4x2 array to a specific order of top-left, top-right, bottom-right, and bottom-left.
Args:
quadrilateral (np.ndarray): A 4x2 array representing the vertices of a quadrilateral.
Returns:
np.ndarray: A 4x2 array with the points ordered as [top-left, top-right, bottom-right, bottom-left].
"""
quadrilateral = quadrilateral.reshape(4, 2)
output_quad = np.zeros([4, 2]).astype(np.float32)
s = quadrilateral.sum(axis=1)
Expand All @@ -33,19 +62,38 @@ def _order_points(quadrilateral: np.ndarray) -> np.ndarray:
output_quad[3] = quadrilateral[np.argmax(diff)]
return output_quad

def find_largest_contour(self):
"""Compute contours for an image and find the biggest one by area."""
def find_largest_contour(self) -> np.ndarray:
"""Finds the largest contour in the image by computing the contours and selecting the one with the greatest area.
Returns:
np.ndarray: The largest contour found in the image.
"""
contours, _ = cv.findContours(
cv.cvtColor(self.image, cv.COLOR_BGR2GRAY), cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE
)
return functools.reduce(lambda a, b: b if cv.contourArea(a) < cv.contourArea(b) else a, contours)

def simplify_polygon(self, contour):
"""Simplify to a polygon with (hopefully four) vertices."""
"""Simplifies a given contour to a polygon with a reduced number of vertices, ideally four.
Args:
contour (np.ndarray): The contour to simplify.
Returns:
np.ndarray: The simplified polygon.
"""
perimeter = cv.arcLength(contour, True)
return cv.approxPolyDP(contour, 0.01 * perimeter, True)

def dewarp(self) -> np.ndarray:
"""Performs a four-point perspective transform to "dewarp" the image.
This involves detecting the largest quadrilateral, simplifying it to a polygon, and
applying a perspective warp to straighten the image into a rectangle.
Returns:
np.ndarray: The perspective-transformed (dewarped) image.
"""
biggest_contour = self.find_largest_contour()
simplified = self.simplify_polygon(biggest_contour)

Expand Down
105 changes: 88 additions & 17 deletions OCR/ocr/services/alignment/backends/image_homography.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,34 @@
"""Aligns two images using image homography algorithms."""
from pathlib import Path

import numpy as np
import cv2 as cv


class ImageHomography:
"""A class to align two images using homography techniques.
Uses Scale-Invariant Feature Transform (SIFT) algorithm to detect keypoints and
compute descriptors for image matching, and then estimates a homography
transformation matrix to align the source image with a template image.
Attributes:
template (np.ndarray): The template image to align against, either as a path or a NumPy array.
match_ratio (float): The ratio used for Lowe's ratio test to filter good matches.
_sift (cv.SIFT): The SIFT detector used to find keypoints and descriptors.
"""
def __init__(self, template: Path | np.ndarray, match_ratio=0.3):
"""Initialize the image homography pipeline with a `template` image."""
"""Initializes the ImageHomography object with a template image.
Optionally include a match ratio for filtering descriptor matches; this must be between 0 and 1.
Args:
template (Path | np.ndarray): The template image, either as a file path or a NumPy array.
match_ratio (float, optional): The ratio threshold for Lowe's ratio test. Default is 0.3.
Raises:
ValueError: If `match_ratio` is not between 0 and 1.
"""
if match_ratio >= 1 or match_ratio <= 0:
raise ValueError("`match_ratio` must be between 0 and 1")

Expand All @@ -18,24 +40,64 @@ def __init__(self, template: Path | np.ndarray, match_ratio=0.3):
self._sift = cv.SIFT_create()

@classmethod
def align(self, source_image, template_image):
def align(self, source_image: np.ndarray, template_image: np.ndarray) -> np.ndarray:
"""Aligns a source image to a template image.
Args:
source_image (np.ndarray): The source image to align.
template_image (np.ndarray): The template image to align to.
Returns:
np.ndarray: The aligned source image.
"""
return ImageHomography(template_image).transform_homography(source_image)

def estimate_self_similarity(self):
"""Calibrate `match_ratio` using a self-similarity metric."""
"""Calibrates the match ratio using a self-similarity metric (not implemented).
Raises:
NotImplementedError: Since this method is not implemented.
"""
raise NotImplementedError

def compute_descriptors(self, img):
"""Compute SIFT descriptors for a target `img`."""
def compute_descriptors(self, img: np.ndarray):
"""Computes the SIFT descriptors for a given image.
These descriptors represent distinctive features in the image that can be used for matching.
Args:
img (np.ndarray): The image for which to compute descriptors.
Returns:
tuple: A 2-element tuple containing the keypoints and their corresponding descriptors.
"""
return self._sift.detectAndCompute(img, None)

def knn_match(self, descriptor_template, descriptor_query):
"""Return k-nearest neighbors match (k=2) between descriptors generated from a template and query image."""
"""Performs k-nearest neighbors matching (k=2) between descriptors to find best homography matches.
Args:
descriptor_template (np.ndarray): The SIFT descriptors from the template image.
descriptor_query (np.ndarray): The SIFT descriptors from the query image.
Returns:
list: A list of k-nearest neighbor matches between the template and query descriptors.
"""
matcher = cv.DescriptorMatcher_create(cv.DescriptorMatcher_FLANNBASED)
return matcher.knnMatch(descriptor_template, descriptor_query, 2)

def estimate_transform_matrix(self, other):
"Estimate the transformation matrix based on homography."
def estimate_transform_matrix(self, other: np.ndarray) -> np.ndarray:
"""Estimates the transformation matrix between the template image and another image.
This function detects keypoints and descriptors, matches them using k-nearest neighbors,
and applies Lowe's ratio test to filter for quality matches.
Args:
other (np.ndarray): The image to estimate the transformation matrix against.
Returns:
np.ndarray: The homography matrix that transforms the other image to align with the template image.
"""
# find the keypoints and descriptors with SIFT
kp1, descriptors1 = self.compute_descriptors(self.template)
kp2, descriptors2 = self.compute_descriptors(other)
Expand All @@ -55,17 +117,26 @@ def estimate_transform_matrix(self, other):
M, _ = cv.findHomography(dst_pts, src_pts, cv.RANSAC, 5.0)
return M

def transform_homography(self, other, min_axis=100, matrix=None):
"""
Run the image homography pipeline against a query image.
def transform_homography(self, other: np.ndarray, min_axis=100, matrix=None) -> np.ndarray:
"""Run the full image homography pipeline against a query image.
Parameters:
min_axis: minimum x- and y-axis length, in pixels, to attempt to do a homography transform.
If the input image is under the axis limits, return the original input image unchanged.
matrix: if specified, a transformation matrix to warp the input image. Otherwise this will be
estimated with `estimate_transform_matrix`.
"""
If both the height and the width of the `other` image are smaller than `min_axis`,
the image is returned unchanged.
If a transformation matrix is provided, it is used directly; otherwise, the matrix is
estimated using `estimate_transform_matrix`.
Args:
other (np.ndarray): The image to be transformed.
min_axis (int, optional): The minimum axis length (in pixels) to attempt the homography transform.
If the image is smaller, it will be returned unchanged. Default is 100.
matrix (np.ndarray, optional): The homography transformation matrix to apply. If not provided,
it will be estimated.
Returns:
np.ndarray: The transformed image if homography was applied, or the original image if it is
smaller than the minimum axis size.
"""
if other.shape[0] < min_axis and other.shape[1] < min_axis:
return other

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
"""
Perspective transforms a base image between 10% and 90% distortion.
"""
"""Perspective transforms a base image between 10% and 90% distortion."""

from pathlib import Path

Expand All @@ -16,9 +14,7 @@ def __init__(self, image: Path):
self.image = Image.open(image)

def make_transform(self, distortion_scale: float) -> object:
"""
Create a transformation matrix for a random perspective transform.
"""
"""Create a transformation matrix for a random perspective transform."""
import torch

# From torchvision. BSD 3-clause
Expand Down
3 changes: 1 addition & 2 deletions OCR/ocr/services/alignment/image_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ def __init__(self, aligner=ImageHomography):
self.aligner = aligner

def align(self, source_image: np.ndarray, template_image: np.ndarray) -> np.ndarray:
"""
Aligns an image using the specified image alignment backend.
"""Aligns an image using the specified image alignment backend.
source_image: the image to be aligned, as a numpy ndarray.
template_image: the image that `source_image` will be aligned against, as a numpy ndarray.
Expand Down
18 changes: 5 additions & 13 deletions OCR/ocr/services/batch_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ def __init__(self, ocr_folder, ground_truth_folder, csv_output_folder):
os.makedirs(self.csv_output_folder, exist_ok=True)

def calculate_batch_metrics(self, ocr_results=None):
"""
Processes OCR and ground truth files and saves individual CSVs.
"""Processes OCR and ground truth files and saves individual CSVs.
Ensures only matching files are processed.
"""
print(f"Loading OCR files from: {self.ocr_folder}")
Expand Down Expand Up @@ -69,9 +68,7 @@ def calculate_batch_metrics(self, ocr_results=None):

@staticmethod
def save_metrics_to_csv(metrics, total_metrics, file_path):
"""
Saves individual and total metrics to a CSV file, including time taken.
"""
"""Saves individual and total metrics to a CSV file, including time taken."""
print(metrics)
metric_keys = list(metrics[0].keys())
total_metric_keys = list(total_metrics.keys())
Expand All @@ -93,9 +90,7 @@ def save_metrics_to_csv(metrics, total_metrics, file_path):

@staticmethod
def save_problematic_segments_to_csv(segments, file_path):
"""
Saves problematic segments (Levenshtein distance >= 1) to a CSV file.
"""
"""Saves problematic segments (Levenshtein distance >= 1) to a CSV file."""
if not segments:
print("No problematic segments found.")
return
Expand All @@ -110,9 +105,7 @@ def save_problematic_segments_to_csv(segments, file_path):
print(f"Problematic segments saved to {file_path}")

def extract_problematic_segments(self, metrics, ocr_file, problematic_segments):
"""
Extracts segments with Levenshtein distance >= 1 and stores them.
"""
"""Extracts segments with Levenshtein distance >= 1 and stores them."""
for metric in metrics:
if metric["levenshtein_distance"] >= 1:
problematic_segments.append(
Expand All @@ -128,8 +121,7 @@ def extract_problematic_segments(self, metrics, ocr_file, problematic_segments):

@staticmethod
def get_files_in_directory(directory):
"""
Returns a sorted list of files in the specified directory.
"""Returns a sorted list of files in the specified directory.
Assumes that files are named consistently for OCR and ground truth.
"""
try:
Expand Down
12 changes: 3 additions & 9 deletions OCR/ocr/services/batch_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ def __init__(self, image_folder, segmentation_template, labels_path, output_fold
os.makedirs(self.output_folder, exist_ok=True)

def process_images(self):
"""
Processes all images and returns OCR results with time taken.
"""
"""Processes all images and returns OCR results with time taken."""
segmenter = ImageSegmenter()
ocr = self.model
results = []
Expand Down Expand Up @@ -56,9 +54,7 @@ def process_images(self):
return results

def segment_ocr_image(self, segmenter, ocr, image_path, image_file):
"""
Segments the image and runs OCR, returning results and time taken.
"""
"""Segments the image and runs OCR, returning results and time taken."""
start_time = time.time()

# Segment the image and run OCR
Expand All @@ -78,9 +74,7 @@ def segment_ocr_image(self, segmenter, ocr, image_path, image_file):
return ocr_result, time_taken

def write_times_to_csv(self, time_dict, csv_output_path):
"""
Writes the time taken for each file to a CSV.
"""
"""Writes the time taken for each file to a CSV."""
csv_file_path = os.path.join(csv_output_path, "time_taken.csv")

with open(csv_file_path, "w", newline="") as csv_file:
Expand Down
Loading

0 comments on commit 0ef1250

Please sign in to comment.