diff --git a/OCR/benchmark_main.py b/OCR/benchmark_main.py index 203d151f..a5b245fe 100644 --- a/OCR/benchmark_main.py +++ b/OCR/benchmark_main.py @@ -46,11 +46,9 @@ def main(): def run_segmentation_and_ocr(args): - """ - Runs segmentation and OCR processing. + """Runs segmentation and OCR processing. Returns OCR results with processing time. """ - model = None if args.model == "tesseract": @@ -70,8 +68,7 @@ def run_segmentation_and_ocr(args): def run_metrics_analysis(args, ocr_results): - """ - Runs metrics analysis based on OCR output and ground truth. + """Runs metrics analysis based on OCR output and ground truth. Uses OCR results to capture time values if available. """ metrics_analysis = BatchMetricsAnalysis(args.output_folder, args.ground_truth_folder, args.csv_output_folder) diff --git a/OCR/ocr/api.py b/OCR/ocr/api.py index 1e9bfb33..2c961bc1 100644 --- a/OCR/ocr/api.py +++ b/OCR/ocr/api.py @@ -149,5 +149,5 @@ async def image_to_text( def start(): - """Launched with `poetry run start` at root level""" + """Launched with `poetry run start` at root level.""" uvicorn.run(app, host="0.0.0.0", port=8000, reload=False) diff --git a/OCR/ocr/services/alignment/backends/four_point_transform.py b/OCR/ocr/services/alignment/backends/four_point_transform.py index cb0954b4..a3535aea 100644 --- a/OCR/ocr/services/alignment/backends/four_point_transform.py +++ b/OCR/ocr/services/alignment/backends/four_point_transform.py @@ -1,6 +1,4 @@ -""" -Uses quadrilaterial edge detection and executes a four-point perspective transform on a source image. -""" +"""Uses quadrilateral edge detection and executes a four-point perspective transform on a source image.""" from pathlib import Path import functools @@ -10,19 +8,50 @@ class FourPointTransform: + """A class to perform a four-point perspective transformation on an image. + + This involves detecting the largest quadrilateral in the image and transforming it + to a standard rectangular form using a perspective warp. 
+ + Attributes: + image (np.ndarray): The input image as a NumPy array. + """ def __init__(self, image: Path | np.ndarray): + """Initializes the FourPointTransform object with an image. + + The image can either be provided as a file path (Path) or a NumPy array. + + Args: + image (Path | np.ndarray): The input image, either as a path to a file or as a NumPy array. + """ if isinstance(image, np.ndarray): self.image = image else: self.image = cv.imread(str(image)) @classmethod - def align(self, source_image, template_image): + def align(self, source_image: np.ndarray, template_image: np.ndarray) -> np.ndarray: + """Aligns a source image to a template image using the four-point transform. + + Args: + source_image (np.ndarray): The source image to be aligned. + template_image (np.ndarray): The template image to align to. + + Returns: + np.ndarray: The transformed image. + """ return FourPointTransform(source_image).dewarp() @staticmethod def _order_points(quadrilateral: np.ndarray) -> np.ndarray: - "Reorder points from a 4x2 input array representing the vertices of a quadrilateral, such that the coordinates of each vertex are arranged in order from top left, top right, bottom right, and bottom left." + """Reorders the points of a quadrilateral from an unordered 4x2 array to a specific order of top-left, top-right, bottom-right, and bottom-left. + + Args: + quadrilateral (np.ndarray): A 4x2 array representing the vertices of a quadrilateral. + + Returns: + np.ndarray: A 4x2 array with the points ordered as [top-left, top-right, bottom-right, bottom-left]. 
+ """ quadrilateral = quadrilateral.reshape(4, 2) output_quad = np.zeros([4, 2]).astype(np.float32) s = quadrilateral.sum(axis=1) @@ -33,19 +62,38 @@ def _order_points(quadrilateral: np.ndarray) -> np.ndarray: output_quad[3] = quadrilateral[np.argmax(diff)] return output_quad - def find_largest_contour(self): - """Compute contours for an image and find the biggest one by area.""" + def find_largest_contour(self) -> np.ndarray: + """Finds the largest contour in the image by computing the contours and selecting the one with the greatest area. + + Returns: + np.ndarray: The largest contour found in the image. + """ contours, _ = cv.findContours( cv.cvtColor(self.image, cv.COLOR_BGR2GRAY), cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE ) return functools.reduce(lambda a, b: b if cv.contourArea(a) < cv.contourArea(b) else a, contours) def simplify_polygon(self, contour): - """Simplify to a polygon with (hopefully four) vertices.""" + """Simplifies a given contour to a polygon with a reduced number of vertices, ideally four. + + Args: + contour (np.ndarray): The contour to simplify. + + Returns: + np.ndarray: The simplified polygon. + """ perimeter = cv.arcLength(contour, True) return cv.approxPolyDP(contour, 0.01 * perimeter, True) def dewarp(self) -> np.ndarray: + """Performs a four-point perspective transform to "dewarp" the image. + + This involves detecting the largest quadrilateral, simplifying it to a polygon, and + applying a perspective warp to straighten the image into a rectangle. + + Returns: + np.ndarray: The perspective-transformed (dewarped) image. 
+ """ biggest_contour = self.find_largest_contour() simplified = self.simplify_polygon(biggest_contour) diff --git a/OCR/ocr/services/alignment/backends/image_homography.py b/OCR/ocr/services/alignment/backends/image_homography.py index 5ab1ddfb..00a4af11 100644 --- a/OCR/ocr/services/alignment/backends/image_homography.py +++ b/OCR/ocr/services/alignment/backends/image_homography.py @@ -1,3 +1,4 @@ +"""Aligns two images using image homography algorithms.""" from pathlib import Path import numpy as np @@ -5,8 +6,29 @@ class ImageHomography: + """A class to align two images using homography techniques. + + Uses Scale-Invariant Feature Transform (SIFT) algorithm to detect keypoints and + compute descriptors for image matching, and then estimates a homography + transformation matrix to align the source image with a template image. + + Attributes: + template (np.ndarray): The template image to align against, either as a path or a NumPy array. + match_ratio (float): The ratio used for Lowe's ratio test to filter good matches. + _sift (cv.SIFT): The SIFT detector used to find keypoints and descriptors. + """ def __init__(self, template: Path | np.ndarray, match_ratio=0.3): - """Initialize the image homography pipeline with a `template` image.""" + """Initializes the ImageHomography object with a template image. + + Optionally include a match ratio for filtering descriptor matches; this must be between 0 and 1. + + Args: + template (Path | np.ndarray): The template image, either as a file path or a NumPy array. + match_ratio (float, optional): The ratio threshold for Lowe's ratio test. Default is 0.3. + + Raises: + ValueError: If `match_ratio` is not between 0 and 1. 
+ """ if match_ratio >= 1 or match_ratio <= 0: raise ValueError("`match_ratio` must be between 0 and 1") @@ -18,24 +40,64 @@ def __init__(self, template: Path | np.ndarray, match_ratio=0.3): self._sift = cv.SIFT_create() @classmethod - def align(self, source_image, template_image): + def align(self, source_image: np.ndarray, template_image: np.ndarray) -> np.ndarray: + """Aligns a source image to a template image. + + Args: + source_image (np.ndarray): The source image to align. + template_image (np.ndarray): The template image to align to. + + Returns: + np.ndarray: The aligned source image. + """ return ImageHomography(template_image).transform_homography(source_image) def estimate_self_similarity(self): - """Calibrate `match_ratio` using a self-similarity metric.""" + """Calibrates the match ratio using a self-similarity metric (not implemented). + + Raises: + NotImplementedError: Since this method is not implemented. + """ raise NotImplementedError - def compute_descriptors(self, img): - """Compute SIFT descriptors for a target `img`.""" + def compute_descriptors(self, img: np.ndarray): + """Computes the SIFT descriptors for a given image. + + These descriptors represent distinctive features in the image that can be used for matching. + + Args: + img (np.ndarray): The image for which to compute descriptors. + + Returns: + tuple: A 2-element tuple containing the keypoints and their corresponding descriptors. + """ return self._sift.detectAndCompute(img, None) def knn_match(self, descriptor_template, descriptor_query): - """Return k-nearest neighbors match (k=2) between descriptors generated from a template and query image.""" + """Performs k-nearest neighbors matching (k=2) between descriptors to find best homography matches. + + Args: + descriptor_template (np.ndarray): The SIFT descriptors from the template image. + descriptor_query (np.ndarray): The SIFT descriptors from the query image. 
+ + Returns: + list: A list of k-nearest neighbor matches between the template and query descriptors. + """ matcher = cv.DescriptorMatcher_create(cv.DescriptorMatcher_FLANNBASED) return matcher.knnMatch(descriptor_template, descriptor_query, 2) - def estimate_transform_matrix(self, other): - "Estimate the transformation matrix based on homography." + def estimate_transform_matrix(self, other: np.ndarray) -> np.ndarray: + """Estimates the transformation matrix between the template image and another image. + + This function detects keypoints and descriptors, matches them using k-nearest neighbors, + and applies Lowe's ratio test to filter for quality matches. + + Args: + other (np.ndarray): The image to estimate the transformation matrix against. + + Returns: + np.ndarray: The homography matrix that transforms the other image to align with the template image. + """ # find the keypoints and descriptors with SIFT kp1, descriptors1 = self.compute_descriptors(self.template) kp2, descriptors2 = self.compute_descriptors(other) @@ -55,17 +117,26 @@ def estimate_transform_matrix(self, other): M, _ = cv.findHomography(dst_pts, src_pts, cv.RANSAC, 5.0) return M - def transform_homography(self, other, min_axis=100, matrix=None): - """ - Run the image homography pipeline against a query image. + def transform_homography(self, other: np.ndarray, min_axis=100, matrix=None) -> np.ndarray: + """Run the full image homography pipeline against a query image. - Parameters: - min_axis: minimum x- and y-axis length, in pixels, to attempt to do a homography transform. - If the input image is under the axis limits, return the original input image unchanged. - matrix: if specified, a transformation matrix to warp the input image. Otherwise this will be - estimated with `estimate_transform_matrix`. - """ + If the size of the `other` image is smaller than the minimum axis length `min_axis`, + the image is returned unchanged. 
+ If a transformation matrix is provided, it is used directly; otherwise, the matrix is + estimated using `estimate_transform_matrix`. + + Args: + other (np.ndarray): The image to be transformed. + min_axis (int, optional): The minimum axis length (in pixels) to attempt the homography transform. + If the image is smaller, it will be returned unchanged. Default is 100. + matrix (np.ndarray, optional): The homography transformation matrix to apply. If not provided, + it will be estimated. + + Returns: + np.ndarray: The transformed image if homography was applied, or the original image if it is + smaller than the minimum axis size. + """ if other.shape[0] < min_axis and other.shape[1] < min_axis: return other diff --git a/OCR/ocr/services/alignment/backends/random_perspective_transform.py b/OCR/ocr/services/alignment/backends/random_perspective_transform.py index 3697d1df..20508cd3 100644 --- a/OCR/ocr/services/alignment/backends/random_perspective_transform.py +++ b/OCR/ocr/services/alignment/backends/random_perspective_transform.py @@ -1,6 +1,4 @@ -""" -Perspective transforms a base image between 10% and 90% distortion. -""" +"""Perspective transforms a base image between 10% and 90% distortion.""" from pathlib import Path @@ -16,9 +14,7 @@ def __init__(self, image: Path): self.image = Image.open(image) def make_transform(self, distortion_scale: float) -> object: - """ - Create a transformation matrix for a random perspective transform. - """ + """Create a transformation matrix for a random perspective transform.""" import torch # From torchvision. 
BSD 3-clause diff --git a/OCR/ocr/services/alignment/image_alignment.py b/OCR/ocr/services/alignment/image_alignment.py index d52472f4..bfc3673c 100644 --- a/OCR/ocr/services/alignment/image_alignment.py +++ b/OCR/ocr/services/alignment/image_alignment.py @@ -8,8 +8,7 @@ def __init__(self, aligner=ImageHomography): self.aligner = aligner def align(self, source_image: np.ndarray, template_image: np.ndarray) -> np.ndarray: - """ - Aligns an image using the specified image alignment backend. + """Aligns an image using the specified image alignment backend. source_image: the image to be aligned, as a numpy ndarray. template_image: the image that `source_image` will be aligned against, as a numpy ndarray. diff --git a/OCR/ocr/services/batch_metrics.py b/OCR/ocr/services/batch_metrics.py index ef5b0fb4..a36a7aee 100644 --- a/OCR/ocr/services/batch_metrics.py +++ b/OCR/ocr/services/batch_metrics.py @@ -12,8 +12,7 @@ def __init__(self, ocr_folder, ground_truth_folder, csv_output_folder): os.makedirs(self.csv_output_folder, exist_ok=True) def calculate_batch_metrics(self, ocr_results=None): - """ - Processes OCR and ground truth files and saves individual CSVs. + """Processes OCR and ground truth files and saves individual CSVs. Ensures only matching files are processed. """ print(f"Loading OCR files from: {self.ocr_folder}") @@ -69,9 +68,7 @@ def calculate_batch_metrics(self, ocr_results=None): @staticmethod def save_metrics_to_csv(metrics, total_metrics, file_path): - """ - Saves individual and total metrics to a CSV file, including time taken. - """ + """Saves individual and total metrics to a CSV file, including time taken.""" print(metrics) metric_keys = list(metrics[0].keys()) total_metric_keys = list(total_metrics.keys()) @@ -93,9 +90,7 @@ def save_metrics_to_csv(metrics, total_metrics, file_path): @staticmethod def save_problematic_segments_to_csv(segments, file_path): - """ - Saves problematic segments (Levenshtein distance >= 1) to a CSV file. 
- """ + """Saves problematic segments (Levenshtein distance >= 1) to a CSV file.""" if not segments: print("No problematic segments found.") return @@ -110,9 +105,7 @@ def save_problematic_segments_to_csv(segments, file_path): print(f"Problematic segments saved to {file_path}") def extract_problematic_segments(self, metrics, ocr_file, problematic_segments): - """ - Extracts segments with Levenshtein distance >= 1 and stores them. - """ + """Extracts segments with Levenshtein distance >= 1 and stores them.""" for metric in metrics: if metric["levenshtein_distance"] >= 1: problematic_segments.append( @@ -128,8 +121,7 @@ def extract_problematic_segments(self, metrics, ocr_file, problematic_segments): @staticmethod def get_files_in_directory(directory): - """ - Returns a sorted list of files in the specified directory. + """Returns a sorted list of files in the specified directory. Assumes that files are named consistently for OCR and ground truth. """ try: diff --git a/OCR/ocr/services/batch_segmentation.py b/OCR/ocr/services/batch_segmentation.py index d2d80f34..5685312c 100644 --- a/OCR/ocr/services/batch_segmentation.py +++ b/OCR/ocr/services/batch_segmentation.py @@ -19,9 +19,7 @@ def __init__(self, image_folder, segmentation_template, labels_path, output_fold os.makedirs(self.output_folder, exist_ok=True) def process_images(self): - """ - Processes all images and returns OCR results with time taken. - """ + """Processes all images and returns OCR results with time taken.""" segmenter = ImageSegmenter() ocr = self.model results = [] @@ -56,9 +54,7 @@ def process_images(self): return results def segment_ocr_image(self, segmenter, ocr, image_path, image_file): - """ - Segments the image and runs OCR, returning results and time taken. 
- """ + """Segments the image and runs OCR, returning results and time taken.""" start_time = time.time() # Segment the image and run OCR @@ -78,9 +74,7 @@ def segment_ocr_image(self, segmenter, ocr, image_path, image_file): return ocr_result, time_taken def write_times_to_csv(self, time_dict, csv_output_path): - """ - Writes the time taken for each file to a CSV. - """ + """Writes the time taken for each file to a CSV.""" csv_file_path = os.path.join(csv_output_path, "time_taken.csv") with open(csv_file_path, "w", newline="") as csv_file: diff --git a/OCR/ocr/services/image_ocr.py b/OCR/ocr/services/image_ocr.py index c8d91f48..e9fdb784 100644 --- a/OCR/ocr/services/image_ocr.py +++ b/OCR/ocr/services/image_ocr.py @@ -1,3 +1,4 @@ +"""Module for OCR using a transformers-based OCR model.""" from collections.abc import Iterator from transformers import TrOCRProcessor, VisionEncoderDecoderModel @@ -7,14 +8,38 @@ class ImageOCR: + """A class for OCR using the transformers-based models. + + Defaults to using the Microsoft TrOCR model from Hugging Face's transformers library. + + Attributes: + processor (TrOCRProcessor): Processor for TrOCR model that prepares images for OCR. + model (VisionEncoderDecoderModel): Pre-trained TrOCR model for extracting text from images. + """ def __init__(self, model="microsoft/trocr-large-printed"): + """Initializes the ImageOCR class with the specified OCR model. + + Args: + model (str, optional): The name of the pre-trained model to use. Default is "microsoft/trocr-large-printed". + + See Also: + * https://huggingface.co/microsoft/trocr-large-printed + """ self.processor = TrOCRProcessor.from_pretrained(model) self.model = VisionEncoderDecoderModel.from_pretrained(model) @staticmethod def compute_line_angle(lines: list) -> Iterator[float]: - """ - Takes the output of cv.HoughLinesP (in x1, y1, x2, y2 format) and computes the angle in degrees based on these endpoints. 
+ """Computes the angle in degrees of the lines detected by the Hough transform, based on their endpoints. + + This method processes the output of `cv.HoughLinesP` (lines in (x1, y1, x2, y2) format) and computes the angle + between each line and the horizontal axis. + + Args: + lines (list): A list of lines represented as a list or tuple of endpoints [x1, y1, x2, y2]. + + Yields: + float: The angle (in degrees) of each line with respect to the horizontal axis. """ for line in lines: start = line[0][0:2] @@ -24,8 +49,16 @@ def compute_line_angle(lines: list) -> Iterator[float]: @staticmethod def merge_bounding_boxes(boxes: list) -> Iterator[list]: - """ - Merges overlapping boxes, passed in (x, y, w, h) format. + """Merges overlapping bounding boxes into a single bounding box. + + Given a list of bounding boxes in (x, y, w, h) format, this function merges overlapping boxes + into one larger box. + + Args: + boxes (list): A list of bounding boxes, where each box is represented as a list or tuple [x, y, w, h]. + + Yields: + list: Merged bounding boxes, represented in [x, y, w, h] format. """ if not boxes: return [] @@ -53,9 +86,17 @@ def merge_bounding_boxes(boxes: list) -> Iterator[list]: yield [current[0], current[1], current[2] - current[0], current[3] - current[1]] def identify_blocks(self, input_image: np.ndarray, kernel: np.ndarray): - """ - Given an input image and a morphological operation kernel, returns unique (non-overlapping) - bounding boxes of potential text regions. + """Identifies potential text blocks in an image by applying morphological operations. + + The function uses the input image, applies thresholding and dilation, and then finds contours to identify + potential text blocks. It then merges overlapping bounding boxes into larger ones. + + Args: + input_image (np.ndarray): The input image to process. + kernel (np.ndarray): The kernel used for morphological operations (dilation). 
+ + Returns: + Iterator[list]: An iterator of merged bounding boxes, each represented as [x, y, w, h]. """ # Invert threshold `input_image` and dilate using `kernel` to "expand" the size of text blocks _, thresh = cv.threshold(cv.cvtColor(input_image, cv.COLOR_BGR2GRAY), 128, 255, cv.THRESH_BINARY_INV) @@ -68,14 +109,17 @@ def identify_blocks(self, input_image: np.ndarray, kernel: np.ndarray): return self.merge_bounding_boxes([cv.boundingRect(contour) for contour in contours]) def deskew_image_text(self, image: np.ndarray, line_length_prop=0.5, max_skew_angle=10) -> np.ndarray: - """ - Deskew an image using Hough transforms to detect lines. + """Deskew an image using Hough transforms to detect lines and rotating the image to correct any skew. Since even small-angled skews can compromise the line segmentation algorithm, this is needed as a preprocessing step. - line_length_prop: typical line length as a fraction of the horizontal size of an image. - max_skew_angle: maximum angle in degrees that a putative line can be skewed before it is removed from consideration - for being too skewed. + Args: + image (np.ndarray): The image to be deskewed. + line_length_prop (float, optional): Proportion of the image's width used to determine line length. Default is 0.5. + max_skew_angle (float, optional): Maximum allowed skew angle for valid lines (in degrees). Default is 10. + + Returns: + np.ndarray: The deskewed image. """ line_length = image.shape[1] * line_length_prop # Flatten image to grayscale for edge detection @@ -95,10 +139,17 @@ def deskew_image_text(self, image: np.ndarray, line_length_prop=0.5, max_skew_an return cv.warpAffine(np.array(image, dtype=np.uint8), rotation_mat, (image.shape[1], image.shape[0])) def split_text_blocks(self, image: np.ndarray, line_length_prop=0.5) -> list[np.ndarray]: - """ - Splits an image with text in it into possibly multiple images, one for each line. 
+ """Splits an image with text in it into (possibly) multiple images, one for each line. + + The function first deskews the image, then uses morphological operations to identify potential lines and words. + It then separates the image into individual text blocks (lines and words). + + Args: + image (np.ndarray): The image to split into text blocks. + line_length_prop (float, optional): Proportion of the image's width used to determine the typical line length. Default is 0.5. - line_length_prop: typical line length as a fraction of the horizontal size of an image. + Returns: + list[np.ndarray]: A list of images representing individual text blocks. """ line_length = image.shape[1] * line_length_prop rotated = self.deskew_image_text(image, line_length_prop) @@ -130,6 +181,19 @@ def split_text_blocks(self, image: np.ndarray, line_length_prop=0.5) -> list[np. return acc def image_to_text(self, segments: dict[str, np.ndarray]) -> dict[str, tuple[str, float]]: + """Converts image segments into text using Transformers OCR. + + For each segment, it extracts the text and the average confidence score. + + Args: + segments (dict[str, np.ndarray]): A dictionary where keys are segment labels (e.g., 'header', 'body'), + and values are NumPy arrays representing the corresponding image segments. + + Returns: + dict[str, tuple[str, float]]: A dictionary where each key corresponds to a segment label, and each value is + a tuple containing the recognized text (as a string) and the confidence score + (as a float) for the recognition. + """ digitized: dict[str, tuple[str, float]] = {} for label, image in segments.items(): if image is None: @@ -155,6 +219,14 @@ def image_to_text(self, segments: dict[str, np.ndarray]) -> dict[str, tuple[str, return digitized def calculate_confidence(self, outputs): + """Calculates the confidence level of the OCR output. + + Args: + outputs: The output of the model, containing prediction scores. 
+ + Returns: + float: The confidence percentage of the OCR output. + """ probs = torch.softmax(outputs.scores[0], dim=-1) max_probs = torch.max(probs, dim=-1).values diff --git a/OCR/ocr/services/metrics_analysis.py b/OCR/ocr/services/metrics_analysis.py index a6c7109c..d40c8b85 100644 --- a/OCR/ocr/services/metrics_analysis.py +++ b/OCR/ocr/services/metrics_analysis.py @@ -5,16 +5,12 @@ class OCRMetrics: - """ - A class to calculate and manage OCR metrics. - - """ + """A class to calculate and manage OCR metrics.""" def __init__( self, ocr_json_path=None, ground_truth_json_path=None, ocr_json=None, ground_truth_json=None, testMode=False ): - """ - Parameters: + """Parameters: ocr_json (dict): The JSON data extracted from OCR. ground_truth_json (dict): The JSON data containing ground truth. """ diff --git a/OCR/ocr/services/phdc_converter/builder.py b/OCR/ocr/services/phdc_converter/builder.py index c2aeba85..a78def7f 100644 --- a/OCR/ocr/services/phdc_converter/builder.py +++ b/OCR/ocr/services/phdc_converter/builder.py @@ -14,22 +14,19 @@ class PHDC: - """ - A class to represent a Public Health Data Container (PHDC) document given a + """A class to represent a Public Health Data Container (PHDC) document given a PHDCBuilder. """ def __init__(self, data: ET.ElementTree = None): - """ - Initializes the PHDC class with a PHDCBuilder. + """Initializes the PHDC class with a PHDCBuilder. :param builder: The PHDCBuilder to use to build the PHDC. """ self.data = data def to_xml_string(self) -> bytes: - """ - Return a string representation of the PHDC XML document as serialized bytes. + """Return a string representation of the PHDC XML document as serialized bytes. :return: The PHDC XML document as serialized bytes. """ @@ -43,8 +40,7 @@ def to_xml_string(self) -> bytes: ).decode() def to_element_tree(self) -> ET.ElementTree: - """ - Return the PHDC XML document as an ElementTree. + """Return the PHDC XML document as an ElementTree. 
:return: The PHDC XML document as an ElementTree. """ @@ -54,30 +50,22 @@ def to_element_tree(self) -> ET.ElementTree: class PHDCBuilder: - """ - A builder class for creating PHDC documents. - """ + """A builder class for creating PHDC documents.""" def __init__(self): - """ - Initializes the PHDCBuilder class and create and empty PHDC. - """ - + """Initializes the PHDCBuilder class and creates an empty PHDC.""" self.input_data: PHDCInputData = None self.phdc = self._build_base_phdc() def set_input_data(self, input_data: PHDCInputData): - """ - Given a PHDCInputData object, set the input data for the PHDCBuilder. + """Given a PHDCInputData object, set the input data for the PHDCBuilder. :param input_data: The PHDCInputData object to use as input data. """ - self.input_data = input_data def _build_base_phdc(self) -> ET.ElementTree: - """ - Create the base PHDC XML document. + """Create the base PHDC XML document. :return: The base PHDC XML document. """ @@ -107,8 +95,7 @@ def _build_base_phdc(self) -> ET.ElementTree: return clinical_document def _get_type_id(self) -> ET.Element: - """ - Creates the type ID element of the PHDC header. + """Creates the type ID element of the PHDC header. :return: XML element of <typeId>. """ @@ -118,8 +105,7 @@ def _get_type_id(self) -> ET.Element: return type_id def _get_id(self) -> ET.Element: - """ - Creates the ID element of the PHDC header. + """Creates the ID element of the PHDC header. :return: XML element of <id>. """ @@ -129,8 +115,7 @@ def _get_id(self) -> ET.Element: return id def _get_effective_time(self) -> ET.Element: - """ - Creates the effectiveTime element of the PHDC header. + """Creates the effectiveTime element of the PHDC header. :return: XML element of <effectiveTime>. """ @@ -141,8 +126,7 @@ def _get_effective_time(self) -> ET.Element: def _get_confidentiality_code( self, confidentiality: Literal["normal", "restricted", "very restricted"] ) -> ET.Element: - """ - Creates the confidentialityCode element of the PHDC header. 
+ """Creates the confidentialityCode element of the PHDC header. :param confidentiality: The confidentiality code to use. :return: XML element of <confidentialityCode>. @@ -159,19 +143,16 @@ def _get_confidentiality_code( return confidentiality_code def _get_realmCode(self) -> ET.Element: - """ - Creates the realmCode element of the PHDC header. + """Creates the realmCode element of the PHDC header. :return: XML element of <realmCode>. """ - realmCode = ET.Element("realmCode") realmCode.set("code", "US") return realmCode def _get_clinical_info_code(self) -> ET.Element: - """ - Creates the code element of the header for a PHDC case report. + """Creates the code element of the header for a PHDC case report. :return: XML element of <code>. """ @@ -183,8 +164,7 @@ def _get_clinical_info_code(self) -> ET.Element: return code def _get_title(self) -> ET.Element: - """ - Creates the title element of the PHDC header. + """Creates the title element of the PHDC header. :return: XML element of <title>. """ @@ -193,8 +173,7 @@ def _get_title(self) -> ET.Element: return title def _get_setId(self) -> ET.Element: - """ - Creates the setId element of the PHDC header. + """Creates the setId element of the PHDC header. :return: XML element of <setId>. """ @@ -204,8 +183,7 @@ def _get_setId(self) -> ET.Element: return setid def _get_version_number(self) -> ET.Element: - """ - Returns the versionNumber element of the PHDC header. + """Returns the versionNumber element of the PHDC header. :return: XML element of <versionNumber>. """ @@ -217,9 +195,7 @@ def _get_version_number(self) -> ET.Element: return version_number def build_header(self): - """ - Builds the header of the PHDC document. 
- """ + """Builds the header of the PHDC document.""" root = self.phdc.getroot() root.append(self._get_realmCode()) root.append(self._get_type_id()) @@ -266,8 +242,7 @@ def _add_observations_to_section( section: ET.Element, data: ET.Element, ) -> ET.Element: - """ - Adds Clinical Observation and Social History Information observations to the + """Adds Clinical Observation and Social History Information observations to the appropriate section. :param section: Section XML element. @@ -283,8 +258,7 @@ def _add_observations_to_section( return section def _build_clinical_info(self) -> ET.Element: - """ - Builds the `ClinicalInformation` XML element, including all hardcoded aspects + """Builds the `ClinicalInformation` XML element, including all hardcoded aspects required to initialize the section. :param observation_data: List of clinical-relevant Observation data. @@ -316,8 +290,7 @@ def _build_clinical_info(self) -> ET.Element: return component def _build_social_history_info(self) -> ET.Element: - """ - Builds the Social History Information XML section, including all hardcoded + """Builds the Social History Information XML section, including all hardcoded aspects required to initialize the section. :return: XML element of SocialHistory data. """ @@ -351,8 +324,7 @@ def _build_social_history_info(self) -> ET.Element: return component def _build_repeating_questions(self) -> ET.Element: - """ - Builds the Repeating Questions XML section, including all hardcoded + """Builds the Repeating Questions XML section, including all hardcoded aspects required to initialize the section. :return: XML element of Repeating Questions data. """ @@ -413,8 +385,7 @@ def _build_repeating_questions(self) -> ET.Element: return component_section def _build_telecom(self, telecom: Telecom) -> ET.Element: - """ - Builds a `telecom` XML element for phone data including phone number (as + """Builds a `telecom` XML element for phone data including phone number (as `value`) and use, if available. 
There are three types of phone uses: 'HP' for home phone, 'WP' for work phone, and 'MC' for mobile phone. @@ -442,8 +413,7 @@ def _build_telecom(self, telecom: Telecom) -> ET.Element: return telecom_data def _add_field(self, parent_element: ET.Element, data: str, field_name: str): - """ - Adds a child element to a parent element given the data and field name. + """Adds a child element to a parent element given the data and field name. :param parent_element: The parent element to add the child element to. :param data: The data to add to the child element. @@ -455,8 +425,7 @@ def _add_field(self, parent_element: ET.Element, data: str, field_name: str): parent_element.append(e) def _build_observation(self, observation: Observation) -> ET.Element: - """ - Creates Entry XML element for observation data. + """Creates Entry XML element for observation data. :param observation: The data for building the observation element as an Entry object. @@ -496,8 +465,7 @@ def _build_observation(self, observation: Observation) -> ET.Element: return observation_data def _set_value_xsi_type(self, observation: Observation) -> Observation: - """ - Ensure that observation elements with a value child element use + """Ensure that observation elements with a value child element use the correct namespace based on the data. :param observation: The observation data being used in _build_observation @@ -540,8 +508,7 @@ def _build_addr( self, address: Address, ) -> ET.Element: - """ - Builds an `addr` XML element for address data. There are two types of address + """Builds an `addr` XML element for address data. There are two types of address uses: 'H' for home address and 'WP' for workplace address. :param address: The data for building the address element as an Address object. @@ -563,13 +530,11 @@ def _build_addr( return address_data def _build_name(self, name: Name) -> ET.Element: - """ - Builds a `name` XML element for name data. + """Builds a `name` XML element for name data. 
:param name: The data for constructing the name element as a Name object. :return: XML element of name data. """ - name_data = ET.Element("name") if name.type is not None: @@ -591,8 +556,7 @@ def _build_name(self, name: Name) -> ET.Element: return name_data def _build_patient(self, patient: Patient) -> ET.Element: - """ - Given a Patient object, build the patient element of the PHDC. + """Given a Patient object, build the patient element of the PHDC. :param patient: The Patient object to use for building the patient element. :return: XML element of patient data. @@ -671,8 +635,7 @@ def _build_recordTarget( address_data: Optional[List[Address]] = None, patient_data: Optional[Patient] = None, ) -> ET.Element: - """ - Builds a `recordTarget` XML element for recordTarget data, which refers to + """Builds a `recordTarget` XML element for recordTarget data, which refers to the medical record of the patient. :param id: recordTarget identifier @@ -731,8 +694,7 @@ def _build_recordTarget( return recordTarget_data def build(self) -> PHDC: - """ - Constructs a PHDC document by building the header and body components. + """Constructs a PHDC document by building the header and body components. :return: A PHDC document as an instance of the PHDC class. """ diff --git a/OCR/ocr/services/phdc_converter/models.py b/OCR/ocr/services/phdc_converter/models.py index 792c7855..4069d956 100644 --- a/OCR/ocr/services/phdc_converter/models.py +++ b/OCR/ocr/services/phdc_converter/models.py @@ -9,9 +9,7 @@ @dataclass class Telecom: - """ - A class containing all of the data elements for a telecom element. - """ + """A class containing all of the data elements for a telecom element.""" value: Optional[str] = None type: Optional[str] = None @@ -21,9 +19,7 @@ class Telecom: @dataclass class Address: - """ - A class containing all of the data elements for an address element. 
- """ + """A class containing all of the data elements for an address element.""" street_address_line_1: Optional[str] = None street_address_line_2: Optional[str] = None @@ -39,9 +35,7 @@ class Address: @dataclass class Name: - """ - A class containing all of the data elements for a name element. - """ + """A class containing all of the data elements for a name element.""" prefix: Optional[str] = None first: Optional[str] = None @@ -55,9 +49,7 @@ class Name: @dataclass class Patient: - """ - A class containing all of the data elements for a patient element. - """ + """A class containing all of the data elements for a patient element.""" name: List[Name] = None address: List[Address] = None @@ -70,9 +62,7 @@ class Patient: @dataclass class Organization: - """ - A class containing all of the data elements for an organization element. - """ + """A class containing all of the data elements for an organization element.""" id: str = None name: str = None @@ -82,9 +72,7 @@ class Organization: @dataclass class CodedElement: - """ - A class containing all of the data elements for a coded element. - """ + """A class containing all of the data elements for a coded element.""" xsi_type: Optional[str] = None code: Optional[str] = None @@ -95,8 +83,7 @@ class CodedElement: text: Optional[Union[str, int]] = None def to_attributes(self) -> Dict[str, str]: - """ - Given a standard CodedElements return a dictionary that can be iterated over to + """Given a standard CodedElements return a dictionary that can be iterated over to produce the corresponding XML element. :return: A dictionary of the CodedElement's attributes @@ -116,9 +103,7 @@ def to_attributes(self) -> Dict[str, str]: @dataclass class Observation: - """ - A class containing all of the data elements for an observation element. 
- """ + """A class containing all of the data elements for an observation element.""" obs_type: str = "laboratory" type_code: Optional[str] = None @@ -154,8 +139,7 @@ class Observation: @dataclass class PHDCInputData: - """ - A class containing all of the data to construct a PHDC document when passed to the + """A class containing all of the data to construct a PHDC document when passed to the PHDCBuilder. """ diff --git a/OCR/ocr/services/phdc_converter/phdc_converter.py b/OCR/ocr/services/phdc_converter/phdc_converter.py index 7a51d7b4..8d92735d 100644 --- a/OCR/ocr/services/phdc_converter/phdc_converter.py +++ b/OCR/ocr/services/phdc_converter/phdc_converter.py @@ -3,9 +3,7 @@ class PHDCConverter: - """ - Parse the OCR data converted to json to create an instance of the Patient data class. - """ + """Parse the OCR data converted to json to create an instance of the Patient data class.""" def parse_patient_data(self, json_data): name = Name(first=json_data.get("patient_first_name", ""), family=json_data.get("patient_last_name", "")) @@ -31,9 +29,7 @@ def parse_patient_data(self, json_data): return patient def generate_phdc_document(self, json_data): - """ - Generate the PHDC document using parsed OCR data. - """ + """Generate the PHDC document using parsed OCR data.""" patient = self.parse_patient_data(json_data) phdc_input = PHDCInputData(patient=patient, type="case_report") diff --git a/OCR/ocr/services/tesseract_ocr.py b/OCR/ocr/services/tesseract_ocr.py index c796a065..03872a77 100644 --- a/OCR/ocr/services/tesseract_ocr.py +++ b/OCR/ocr/services/tesseract_ocr.py @@ -1,3 +1,5 @@ +"""Module for OCR services using a Tesseract backend.""" + import os import tesserocr @@ -7,30 +9,45 @@ class TesseractOCR: + """A class to provide OCR services using Tesseract as the backend. + + This class supports configuring Tesseract's page segmentation modes and customizing its behavior + through internal variables. 
+ + Attributes: + psm (int): The page segmentation mode for Tesseract, specifying how Tesseract interprets the structure of the document. + variables (dict): A dictionary of variables to customize Tesseract's behavior. + + See Also: + * https://github.com/sirfz/tesserocr/blob/bbe0fb8edabdcc990f1e6fa9334c0747c2ac76ee/tesserocr/__init__.pyi#L47 + * https://tesseract-ocr.github.io/tessdoc/tess3/ControlParams.html + """ def __init__(self, psm=PSM.AUTO, variables=dict()): - """ - Initialize the tesseract OCR model. + """Initializes the TesseractOCR object with the specified page segmentation mode and internal variables. - `psm` (int): an enum (from `PSM`) that defines tesseract's page segmentation mode. Default is `AUTO`. - `variables` (dict): a dict to customize tesseract's behavior with internal variables + Args: + psm (int, optional): The page segmentation mode (from `tesserocr.PSM`). Default is `PSM.AUTO`. + variables (dict, optional): A dictionary of variables to customize Tesseract's behavior. Default is an empty dictionary. """ self.psm = psm self.variables = variables @staticmethod def _guess_tessdata_path(wanted_lang="eng") -> bytes: - """ - Attempts to guess potential locations for the `tessdata` folder. + """Attempts to guess potential locations for the `tessdata` folder. The `tessdata` folder is needed to use pre-trained Tesseract OCR data, though the automatic detection - provided in `tesserocr` may not be reliable. Instead iterate over common paths on various systems (e.g., - Red Hat, Ubuntu, macOS) and check for the presence of a `tessdata` folder. + provided in `tesserocr` may not be reliable. - If `TESSDATA_PREFIX` is available in the environment, this function will check that location first. - If all guessed locations do not exist, fall back to automatic detection provided by `tesserocr` and - the tesseract API. 
+ The function first checks the path defined by the environment variable `TESSDATA_PREFIX` (if available), + and then falls back to searching several default candidate paths on various systems (e.g., Red Hat, Ubuntu, + macOS). If no valid path is found, it uses the automatic detection provided by the Tesseract API, which may fail. - `wanted_lang` (str): a desired language to search for. Defaults to English `eng`. + Args: + wanted_lang (str, optional): The desired language to search for in the `tessdata` folder. Default is 'eng' (English). + + Returns: + bytes: The path to the `tessdata` directory containing the OCR language files. """ candidate_paths = [ "/usr/local/share/tesseract/tessdata", @@ -63,6 +80,22 @@ def _guess_tessdata_path(wanted_lang="eng") -> bytes: return tesserocr.get_languages()[0] def image_to_text(self, segments: dict[str, np.ndarray]) -> dict[str, tuple[str, float]]: + """Converts image segments into text using Tesseract OCR. + + The function processes a dictionary of image segments, where each key corresponds to a segment label, + and each value is a NumPy array representing an image segment. + + For each segment, it extracts the text and the average confidence score returned from the Tesseract API. + + Args: + segments (dict[str, np.ndarray]): A dictionary where keys are segment labels (e.g., 'header', 'body'), + and values are NumPy arrays representing the corresponding image segments. + + Returns: + dict[str, tuple[str, float]]: A dictionary where each key corresponds to a segment label, and each value is + a tuple containing the recognized text (as a string) and the confidence score + (as a float) for the recognition. 
+ """ digitized: dict[str, tuple[str, float]] = {} with tesserocr.PyTessBaseAPI(psm=self.psm, variables=self.variables, path=self._guess_tessdata_path()) as api: for label, image in segments.items(): diff --git a/OCR/pyproject.toml b/OCR/pyproject.toml index 1380120a..a79782c4 100644 --- a/OCR/pyproject.toml +++ b/OCR/pyproject.toml @@ -38,3 +38,14 @@ build = "ocr.pyinstaller:install" [tool.ruff] line-length = 118 target-version = "py310" + +[tool.ruff.lint] +select = ["D"] + +[tool.ruff.lint.pydocstyle] +convention = "google" + +[tool.ruff.lint.per-file-ignores] +# Ignore test directories and __init__.py +"tests/**" = ["D"] +"__init__.py" = ["D"]