From 612d5cc2933cb93fc79efb27be70f0701947f14d Mon Sep 17 00:00:00 2001 From: Jonathan Perdomo Date: Thu, 25 Jul 2024 13:16:53 -0400 Subject: [PATCH] 52 base modification tags (#53) * Add base modification (methylation) QC * Add POD5 + dorado basecalled BAM QC --- .github/workflows/build-test.yml | 2 +- .gitignore | 2 + Makefile | 12 +- include/hts_reader.h | 11 + include/input_parameters.h | 4 +- include/output_data.h | 194 ++++++++++++------ include/ref_query.h | 52 +++++ src/bam_module.cpp | 122 ++++++++++- src/cli.py | 86 ++++++-- src/fast5_module.cpp | 9 +- src/generate_html.py | 35 ++-- src/hts_reader.cpp | 335 +++++++++++++++++++++++-------- src/input_parameters.cpp | 24 ++- src/lrst.i | 64 ++++++ src/output_data.cpp | 161 +++++++++++++-- src/plot_utils.py | 175 +++++++++++++--- src/pod5_module.py | 2 +- src/ref_query.cpp | 255 +++++++++++++++++++++++ tests/test_general.py | 163 ++++++++++++++- 19 files changed, 1454 insertions(+), 254 deletions(-) create mode 100644 include/ref_query.h create mode 100644 src/ref_query.cpp diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index f7bdff3..1c87ca8 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -18,7 +18,7 @@ jobs: uses: dsaltares/fetch-gh-release-asset@1.0.0 with: repo: 'WGLab/LongReadSum' - version: 'tags/v0.1.0' + version: 'tags/v1.3.1' file: 'SampleData.zip' - name: Unzip assets diff --git a/.gitignore b/.gitignore index 6b4a66d..ad4fc6a 100644 --- a/.gitignore +++ b/.gitignore @@ -29,3 +29,5 @@ SampleData # Testing scripts linktoscripts +single_mod.summary.txt +single_mod.tin.xls diff --git a/Makefile b/Makefile index 271d6ba..ec13daf 100644 --- a/Makefile +++ b/Makefile @@ -2,9 +2,13 @@ INCL_DIR := $(CURDIR)/include SRC_DIR := $(CURDIR)/src LIB_DIR := $(CURDIR)/lib -all: - # Generate the SWIG Python/C++ wrappers +# All targets +all: swig_build compile + +# Generate the SWIG Python/C++ wrappers +swig_build: swig -c++ -python -outdir $(LIB_DIR) -I$(INCL_DIR) -o $(SRC_DIR)/lrst_wrap.cpp $(SRC_DIR)/lrst.i - # Compile the C++ shared libraries into lib/ - python setup.py build_ext --build-lib $(LIB_DIR) +# Compile the C++ shared libraries into lib/ +compile: + python3 setup.py build_ext --build-lib $(LIB_DIR) diff --git a/include/hts_reader.h b/include/hts_reader.h index 411e2a2..4704d28 100644 --- a/include/hts_reader.h +++ b/include/hts_reader.h @@ -9,6 +9,7 @@ #include #include #include +#include #include "output_data.h" @@ -27,6 +28,11 @@ class HTSReader { bam_hdr_t* header; // read the BAM header bam1_t* record; int record_count = 0; + + // Atomic flags for whether certain BAM flags are present + std::atomic_flag has_nm_tag = ATOMIC_FLAG_INIT; // NM tag for number of mismatches using edit distance + std::atomic_flag has_mm_ml_tags = ATOMIC_FLAG_INIT; // MM and ML tags for modified base information + std::atomic_flag has_pod5_tags = ATOMIC_FLAG_INIT; // POD5 tags for signal-level information (ts, ns) // Bool for whether the reading is complete bool reading_complete = false; @@ -43,6 +49,11 @@ class HTSReader { // Return the number of records in the BAM file using the BAM index int64_t getNumRecords(const std::string &bam_file_name); + std::map getQueryToRefMap(bam1_t *record); + + // Add a modification to the base modification map + void addModificationToQueryMap(std::map> &base_modifications, int32_t pos, char mod_type, char canonical_base, double likelihood, int strand); + HTSReader(const std::string &bam_file_name); ~HTSReader(); }; diff --git a/include/input_parameters.h b/include/input_parameters.h index 614ee0d..b52571d 100644 --- a/include/input_parameters.h +++ b/include/input_parameters.h @@ -32,9 +32,11 @@ class Input_Para{ std::string rrms_csv; // CSV file with accepted/rejected read IDs (RRMS module) bool rrms_filter; // Generate RRMS stats for accepted (true) or rejected (false) reads std::unordered_set rrms_read_ids; // List of read IDs from RRMS CSV file (accepted or rejected) + std::string ref_genome; // Reference genome file for BAM base modification analysis + double base_mod_threshold; // Base modification threshold for BAM base modification analysis // Functions - std::string add_input_file(const std::string& _ip_file); + std::string add_input_file(const std::string& input_filepath); Input_Para(); diff --git a/include/output_data.h b/include/output_data.h index 1bfd0a4..7e15f6d 100644 --- a/include/output_data.h +++ b/include/output_data.h @@ -9,6 +9,7 @@ Define the output structures for each module. #include #include #include +#include #include #include "input_parameters.h" @@ -110,58 +111,142 @@ class Output_FQ : public Output_FA }; +// Define the base modification data structure (modification type, canonical +// base, likelihood, strand: 0 for forward, 1 for reverse, and CpG flag: T/F) +using Base_Modification = std::tuple; + +// Define the signal-level data structure for POD5 (ts, ns, move table vector) +using POD5_Signal_Data = std::tuple>; + +// Base class for storing a read's base signal data +class Base_Signals +{ +public: + std::string read_name; + int base_count; + std::string sequence_data_str; // Sequence of bases + std::vector> basecall_signals; // 2D vector of base signals + + // Methods + int getBaseCount(); + std::string getReadName(); + std::string getSequenceString(); + std::vector> getDataVector(); + Base_Signals(std::string read_name, std::string sequence_data_str, std::vector> basecall_signals); +}; + +// Base class for storing a read's sequence and move table (basecalled POD5 in +// BAM format) +class Base_Move_Table +{ +public: + std::string sequence_data_str; // Sequence of bases + std::vector base_signal_index; // 2D vector of signal indices for each base + int sequence_start; // Signal index of the first base (ts) + int sequence_end; // Signal index of the last base (ns) + + // Methods + std::string getSequenceString(); + std::vector getBaseSignalIndex(); + int getSequenceStart(); + int getSequenceEnd(); + Base_Move_Table(std::string sequence_data_str, std::vector base_signal_index, int start, int end); + Base_Move_Table(); +}; + + // BAM output class Output_BAM : public Output_FQ { public: - uint64_t num_primary_alignment = ZeroDefault; // the number of primary alignment/ - uint64_t num_secondary_alignment = ZeroDefault; // the number of secondary alignment - uint64_t num_reads_with_secondary_alignment = ZeroDefault; // the number of long reads with the secondary alignment: one read might have multiple seconard alignment - uint64_t num_supplementary_alignment = ZeroDefault; // the number of supplementary alignment - uint64_t num_reads_with_supplementary_alignment = ZeroDefault; // the number of long reads with secondary alignment; - uint64_t num_reads_with_both_secondary_supplementary_alignment = ZeroDefault; // the number of long reads with both secondary and supplementary alignment. - uint64_t forward_alignment = ZeroDefault; // Total number of forward alignments - uint64_t reverse_alignment = ZeroDefault; // Total number of reverse alignments - int reads_with_mods = ZeroDefault; // Total number of reads with modification tags - int reads_with_mods_pos_strand = ZeroDefault; // Total number of reads with modification tags on the positive strand - int reads_with_mods_neg_strand = ZeroDefault; // Total number of reads with modification tags on the negative strand - - // Map of reads with supplementary alignments - std::map reads_with_supplementary; - - // Map of reads with secondary alignments - std::map reads_with_secondary; - - // Similar to Output_FA: below are for mapped. - uint64_t num_matched_bases = ZeroDefault; // the number of matched bases with = - uint64_t num_mismatched_bases = ZeroDefault; // the number of mismatched bases X - uint64_t num_ins_bases = ZeroDefault; // the number of inserted bases; - uint64_t num_del_bases = ZeroDefault; // the number of deleted bases; - uint64_t num_clip_bases = ZeroDefault; // the number of soft-clipped bases; - - // The number of columns can be calculated by summing over the lengths of M/I/D CIGAR operators - int num_columns = ZeroDefault; // the number of columns - double percent_identity = ZeroDefault; // Percent identity = (num columns - NM) / num columns - - std::vector accuracy_per_read; - - Basic_Seq_Statistics mapped_long_read_info; - Basic_Seq_Statistics unmapped_long_read_info; - - Basic_Seq_Quality_Statistics mapped_seq_quality_info; - Basic_Seq_Quality_Statistics unmapped_seq_quality_info; - - // Add a batch of records to the output - void add(Output_BAM &t_output_bam); - - // Calculate QC across all records - void global_sum(); + uint64_t num_primary_alignment = ZeroDefault; // the number of primary alignment/ + uint64_t num_secondary_alignment = ZeroDefault; // the number of secondary alignment + uint64_t num_reads_with_secondary_alignment = ZeroDefault; // the number of long reads with the secondary alignment: one read might have multiple seconard alignment + uint64_t num_supplementary_alignment = ZeroDefault; // the number of supplementary alignment + uint64_t num_reads_with_supplementary_alignment = ZeroDefault; // the number of long reads with secondary alignment; + uint64_t num_reads_with_both_secondary_supplementary_alignment = ZeroDefault; // the number of long reads with both secondary and supplementary alignment. + uint64_t forward_alignment = ZeroDefault; // Total number of forward alignments + uint64_t reverse_alignment = ZeroDefault; // Total number of reverse alignments + std::map reads_with_supplementary; // Map of reads with supplementary alignments + std::map reads_with_secondary; // Map of reads with secondary alignments + + // Similar to Output_FA: below are for mapped. + uint64_t num_matched_bases = ZeroDefault; // the number of matched bases with = + uint64_t num_mismatched_bases = ZeroDefault; // the number of mismatched bases X + uint64_t num_ins_bases = ZeroDefault; // the number of inserted bases; + uint64_t num_del_bases = ZeroDefault; // the number of deleted bases; + uint64_t num_clip_bases = ZeroDefault; // the number of soft-clipped bases; + + // The number of columns can be calculated by summing over the lengths of M/I/D CIGAR operators + int num_columns = ZeroDefault; // the number of columns + double percent_identity = ZeroDefault; // Percent identity = (num columns - NM) / num columns + std::vector accuracy_per_read; + + // Number of modified bases by position in the reference: + // chr -> reference position -> (modification type, canonical base, maximum + // likelihood, strand) + std::map> base_modifications; + uint64_t modified_prediction_count = ZeroDefault; // Total number of modified base predictions + uint64_t modified_base_count = ZeroDefault; // Total number of modified bases mapped to the reference genome + uint64_t modified_base_count_forward = ZeroDefault; // Total number of modified bases in the genome on the forward strand + uint64_t modified_base_count_reverse = ZeroDefault; // Total number of modified bases in the genome on the reverse strand + // uint64_t c_modified_base_count = ZeroDefault; // Total C modified bases + uint64_t cpg_modified_base_count = ZeroDefault; // Total C modified bases in CpG sites (combined forward and reverse) + uint64_t cpg_modified_base_count_forward = ZeroDefault; // Total C modified bases in CpG sites on the forward strand + uint64_t cpg_modified_base_count_reverse = ZeroDefault; // Total C modified bases in CpG sites on the reverse strand + uint64_t cpg_genome_count = ZeroDefault; // Total number of CpG sites in the genome + double percent_modified_cpg = ZeroDefault; // Percentage of CpG sites with modified bases (forward and reverse) + + // Test counts + uint64_t test_count = ZeroDefault; // Test count + uint64_t test_count2 = ZeroDefault; // Test count 2 + + // Counts for each type of modification: + // Modification type -> count + std::map modification_type_counts; + + // Signal data section + int read_count = ZeroDefault; + int base_count = ZeroDefault; + // std::vector read_move_table; + std::unordered_map read_move_table; + // std::map read_move_table = {}; - // Save the output to a summary text file - void save_summary(std::string &output_file, Input_Para ¶ms, Output_BAM &output_data); + // POD5 signal-level information is stored in a map of read names to a map of + // reference positions to a tuple of (ts, ns, move table vector) + std::unordered_map pod5_signal_data; - Output_BAM(); - ~Output_BAM(); + Basic_Seq_Statistics mapped_long_read_info; + Basic_Seq_Statistics unmapped_long_read_info; + + Basic_Seq_Quality_Statistics mapped_seq_quality_info; + Basic_Seq_Quality_Statistics unmapped_seq_quality_info; + + // Add modified base data + void add_modification(std::string chr, int32_t ref_pos, char mod_type, char canonical_base, double likelihood, int strand); + + // Return the modification information + std::map> get_modifications(); + + // POD5 signal data functions + int getReadCount(); + void addReadMoveTable(std::string read_name, std::string sequence_data_str, std::vector move_table, int start, int end); + std::vector getReadMoveTable(std::string read_id); + std::string getReadSequence(std::string read_id); + int getReadSequenceStart(std::string read_id); + int getReadSequenceEnd(std::string read_id); + + // Add a batch of records to the output + void add(Output_BAM &t_output_bam); + + // Calculate QC across all records + void global_sum(); + + // Save the output to a summary text file + void save_summary(std::string &output_file, Input_Para ¶ms, Output_BAM &output_data); + + Output_BAM(); + ~Output_BAM(); }; @@ -193,23 +278,6 @@ class Output_SeqTxt : public Output_Info void global_sum(); }; -// Base class for storing a read's base signal data -class Base_Signals -{ -public: - std::string read_name; - int base_count; - std::string sequence_data_str; // Sequence of bases - std::vector> basecall_signals; // 2D vector of base signals - - // Methods - int getBaseCount(); - std::string getReadName(); - std::string getSequenceString(); - std::vector> getDataVector(); - Base_Signals(std::string read_name, std::string sequence_data_str, std::vector> basecall_signals); -}; - // FAST5 output class Output_FAST5 : public Output_FA { diff --git a/include/ref_query.h b/include/ref_query.h new file mode 100644 index 0000000..96f44d3 --- /dev/null +++ b/include/ref_query.h @@ -0,0 +1,52 @@ +// RefQuery: A class for querying a reference genome in FASTA format + +#ifndef REF_QUERY_H +#define REF_QUERY_H + +#include +#include +#include +#include +#include + +class RefQuery { + private: + std::string fasta_filepath; + std::vector chromosomes; + std::unordered_map chr_to_seq; +// uint32_t cpg_site_count = 0; + uint64_t cpg_modified_count = 0; + uint64_t cpg_total_count = 0; + uint64_t test_count = 0; + + // Map of reference position (0-indexed) to CpG site (true/false) on the + // forward strand for each chromosome + // std::map> chr_pos_to_cpg; + + // Map of reference position (0-indexed) to CpG site (true/false) on the + // reverse strand for each chromosome + // std::map> chr_pos_to_cpg_rev; + + // Map of reference position (0-indexed) to CpG site (true/false) for + // all chromosomes +// std::map> chr_pos_to_cpg; + + // Map of chromosome to CpG site positions + std::unordered_map> chr_to_cpg; + + // Map of chromosome to CpG site positions with modifications + std::unordered_map> chr_to_cpg_mod; + + // Reverse strand CpG site map + // std::map> chr_pos_to_cpg_rev; + + public: + int setFilepath(std::string fasta_filepath); + std::string getFilepath(); + std::string getBase(std::string chr, int64_t pos); + void generateCpGMap(); + void addCpGSiteModification(std::string chr, int64_t pos, int strand); + std::pair getCpGModificationCounts(int strand); +}; + +#endif // REF_QUERY_H diff --git a/src/bam_module.cpp b/src/bam_module.cpp index 39110df..810c9d6 100644 --- a/src/bam_module.cpp +++ b/src/bam_module.cpp @@ -10,9 +10,13 @@ Class for generating BAM file statistics. Records are accessed using multi-threa #include #include #include +#include #include "bam_module.h" +#include "utils.h" +#include "ref_query.h" // For reference genome analysis + // Run the BAM module int BAM_Module::run(Input_Para &input_params, Output_BAM &final_output) { @@ -65,7 +69,7 @@ int BAM_Module::calculateStatistics(Input_Para &input_params, Output_BAM &final_ // Create a BAM reader std::string filepath(input_params.input_files[this->file_index]); HTSReader reader(filepath); - std::cout<<"Processing file: "<< filepath << std::endl; + std::cout << "Processing file: "<< filepath << std::endl; // Get the number of threads (Set to 1 if the number of threads is not specified/invalid) int thread_count = input_params.threads; @@ -112,9 +116,6 @@ int BAM_Module::calculateStatistics(Input_Para &input_params, Output_BAM &final_ std::cout << "Generating " << thread_count << " thread(s)..." << std::endl; std::vector thread_vector; for (int thread_index=0; thread_index 0){ + + // Determine CpG modification rate + // First, read the reference genome + if (input_params.ref_genome != ""){ + std::cout << "Reading reference genome for CpG modification rate calculation..." << std::endl; + RefQuery ref_query; + ref_query.setFilepath(input_params.ref_genome); + std::cout << "Reference genome read." << std::endl; + + // Loop through the base modifications and find the CpG + // modifications + double base_mod_threshold = input_params.base_mod_threshold; + for (auto const &it : final_output.base_modifications) { + std::string chr = it.first; + std::map base_mods = it.second; + + // Loop through the base modifications + for (auto const &it2 : base_mods) { + + // Get the base modification information + //int32_t ref_pos = it2.first; + int64_t ref_pos = (int64_t) it2.first; + Base_Modification mod = it2.second; + char mod_type = std::get<0>(mod); + char canonical_base_char = std::toupper(std::get<1>(mod)); + std::string canonical_base(1, canonical_base_char); + + double probability = std::get<2>(mod); + int strand = std::get<3>(mod); + if (probability >= base_mod_threshold) + { + // Update the modified base count + final_output.modified_base_count++; + + // Update the strand-specific modified base counts + if (strand == 0) { + final_output.modified_base_count_forward++; + } else if (strand == 1) { + final_output.modified_base_count_reverse++; + } + + // Get CpG modification information for cytosines + std::string ref_base = ref_query.getBase(chr, ref_pos); + if (canonical_base == "C") { + + if ((ref_base == "C") && (mod_type == 'm') && (strand == 0)) + { + + // Determine if it resides in a CpG site + std::string next_base = ref_query.getBase(chr, ref_pos + 1); + if (next_base == "G") { + + // Update the strand-specific CpG modified base + // count + final_output.cpg_modified_base_count_forward++; + + // Update the CpG modification flag + std::get<4>(final_output.base_modifications[chr][ref_pos]) = true; + + // Add the CpG site modification + ref_query.addCpGSiteModification(chr, ref_pos, strand); + } + + } else if ((ref_base == "G") && (mod_type == 'm') && (strand == 1)) { + + // Determine if it resides in a CpG site + std::string previous_base = ref_query.getBase(chr, ref_pos - 1); + + if (previous_base == "C") + { + // Update the strand-specific CpG modified base + // count + final_output.cpg_modified_base_count_reverse++; + + // Update the CpG modification flag + std::get<4>(final_output.base_modifications[chr][ref_pos]) = true; + + // Add the CpG site modification + ref_query.addCpGSiteModification(chr, ref_pos, strand); + } + } + } + } + } + } + + // Calculate CpG site statistics + if (final_output.cpg_modified_base_count_forward > 0 || final_output.cpg_modified_base_count_reverse > 0) + { + // Calculate the number of CpG sites with modifications + std::pair cpg_mod_counts = ref_query.getCpGModificationCounts(0); + final_output.cpg_modified_base_count = cpg_mod_counts.first; + final_output.cpg_genome_count = cpg_mod_counts.first + cpg_mod_counts.second; + + // Calculate the CpG modification rate + double cpg_mod_rate = (double)cpg_mod_counts.first / (double)(cpg_mod_counts.first + cpg_mod_counts.second); + final_output.percent_modified_cpg = cpg_mod_rate * 100; + } + + std::cout << "Number of CpG modified bases: " << final_output.cpg_modified_base_count << std::endl; + std::cout << "Total number of modified bases: " << final_output.modified_base_count << std::endl; + } + } + // Calculate the global sums across all records std::cout << "Calculating summary QC..." << std::endl; final_output.global_sum(); @@ -236,15 +343,8 @@ std::unordered_set BAM_Module::readRRMSFile(std::string rrms_csv_fi } } - // Close the file rrms_file.close(); - // // Print the first 10 read IDs - // std::cout << "First 10 read IDs:" << std::endl; - // for (int i=0; i<10; i++){ - // std::cout << rrms_read_ids[i] << std::endl; - // } - return rrms_read_ids; } \ No newline at end of file diff --git a/src/cli.py b/src/cli.py index 8f637ae..3592914 100644 --- a/src/cli.py +++ b/src/cli.py @@ -168,7 +168,7 @@ def fq_module(margs): fq_html_gen = generate_html.ST_HTML_Generator( [["basic_st", "read_length_bar", "read_length_hist", "base_counts", "base_quality", "read_avg_base_quality"], "FASTQ QC", param_dict], plot_filepaths, static=False) - fq_html_gen.generate_st_html() + fq_html_gen.generate_html() logging.info("Done. Output files are in %s", param_dict["output_folder"]) else: @@ -208,13 +208,14 @@ def fa_module(margs): fa_html_gen = generate_html.ST_HTML_Generator( [["basic_st", "read_length_bar", "read_length_hist", "base_counts"], "FASTA QC", param_dict], plot_filepaths, static=True) - fa_html_gen.generate_st_html() + fa_html_gen.generate_html() logging.info("Done. Output files are in %s", param_dict["output_folder"]) else: logging.error("QC did not generate.") def bam_module(margs): + """BAM file input module.""" # Get the filetype-specific parameters param_dict = get_common_param(margs) if param_dict == {}: @@ -231,8 +232,18 @@ def bam_module(margs): input_para.other_flags = (1 if param_dict["detail"] > 0 else 0); input_para.output_folder = str(param_dict["output_folder"]) input_para.out_prefix = str(param_dict["out_prefix"]) - for _ipf in param_dict["input_files"]: - input_para.add_input_file(str(_ipf)) + + # Set the reference genome file and base modification threshold + param_dict["ref"] = margs.ref if margs.ref != "" or margs.ref is not None else "" + param_dict["modprob"] = margs.modprob + + logging.info("Reference genome file is " + param_dict["ref"]) + input_para.ref_genome = param_dict["ref"] + logging.info("Updated ref genome file is " + input_para.ref_genome) + + input_para.base_mod_threshold = margs.modprob + for input_file in param_dict["input_files"]: + input_para.add_input_file(str(input_file)) bam_output = lrst.Output_BAM() exit_code = lrst.callBAMModule(input_para, bam_output) @@ -241,11 +252,18 @@ def bam_module(margs): logging.info("Generating HTML report...") plot_filepaths = plot(bam_output, param_dict, 'BAM') - # TODO: Add read average base quality plot (not currently generated by bam_plot.plot) + # Set the list of QC information to display + qc_info_list = ["basic_st", "read_alignments_bar", "base_alignments_bar", "read_length_bar", "read_length_hist", "base_counts", "basic_info", "base_quality"] + + # If base modifications were found, add the base modification plots + # after the first table + if bam_output.modified_base_count > 0: + qc_info_list.insert(1, "base_mods") + + # If base modifications were found, add the base modification plots bam_html_gen = generate_html.ST_HTML_Generator( - [["basic_st", "read_alignments_bar", "base_alignments_bar", "read_length_bar", "read_length_hist", "base_counts", "basic_info", - "base_quality"], "BAM QC", param_dict], plot_filepaths, static=False) - bam_html_gen.generate_st_html() + [qc_info_list, "BAM QC", param_dict], plot_filepaths, static=False) + bam_html_gen.generate_html() logging.info("Done. Output files are in %s", param_dict["output_folder"]) else: @@ -300,7 +318,7 @@ def rrms_module(margs): bam_html_gen = generate_html.ST_HTML_Generator( [["basic_st", "read_alignments_bar", "base_alignments_bar", "read_length_bar", "read_length_hist", "base_counts", "basic_info", "base_quality"], "BAM QC", param_dict], plot_filepaths, static=False) - bam_html_gen.generate_st_html() + bam_html_gen.generate_html() logging.info("Done. Output files are in %s", param_dict["output_folder"]) else: @@ -348,7 +366,7 @@ def seqtxt_module(margs): seqtxt_html_gen = generate_html.ST_HTML_Generator( [["basic_st", "read_length_bar", "read_length_hist", "basic_info"], "sequencing_summary.txt QC", param_dict], plot_filepaths, static=False) - seqtxt_html_gen.generate_st_html() + seqtxt_html_gen.generate_html() logging.info("Done. Output files are in %s", param_dict["output_folder"]) else: logging.error("QC did not generate.") @@ -384,7 +402,7 @@ def fast5_module(margs): fast5_html_obj = generate_html.ST_HTML_Generator( [["basic_st", "read_length_bar", "read_length_hist", "base_counts", "basic_info", "base_quality", "read_avg_base_quality"], "FAST5 QC", param_dict], plot_filepaths, static=False) - fast5_html_obj.generate_st_html() + fast5_html_obj.generate_html() logging.info("Done. Output files are in %s", param_dict["output_folder"]) else: @@ -428,7 +446,7 @@ def fast5_signal_module(margs): fast5_html_obj = generate_html.ST_HTML_Generator( [["basic_st", "read_length_bar", "read_length_hist", "base_counts", "basic_info", "base_quality", "read_avg_base_quality", "ont_signal"], "FAST5 QC", param_dict], plot_filepaths, static=False) - fast5_html_obj.generate_st_html(signal_plots=True) + fast5_html_obj.generate_html(signal_plots=True) logging.info("Done. Output files are in %s", param_dict["output_folder"]) else: @@ -464,17 +482,41 @@ def pod5_module(margs): else: input_para['read_ids'] = "" + # Get the basecalled BAM file if specified, and run the BAM module + basecall_data = False + bam_output = None + basecalls = margs.basecalls + if basecalls != "" and basecalls is not None: + basecalls_input = lrst.Input_Para() + basecalls_input.threads = param_dict["threads"] + basecalls_input.rdm_seed = param_dict["random_seed"] + basecalls_input.downsample_percentage = param_dict["downsample_percentage"] + basecalls_input.other_flags = (1 if param_dict["detail"] > 0 else 0) + basecalls_input.output_folder = str(param_dict["output_folder"]) + basecalls_input.out_prefix = str(param_dict["out_prefix"]) + basecalls_input.add_input_file(basecalls) + bam_output = lrst.Output_BAM() + exit_code = lrst.callBAMModule(basecalls_input, bam_output) + if exit_code == 0: + basecall_data = True + logging.info("Basecalled BAM QC generated.") + read_signal_dict = generate_pod5_qc(input_para) if read_signal_dict is not None: logging.info("QC generated.") logging.info("Generating HTML report...") - plot_filepaths = plot_pod5(read_signal_dict, param_dict) + + if basecall_data: + plot_filepaths = plot_pod5(read_signal_dict, param_dict, bam_output) + else: + plot_filepaths = plot(read_signal_dict, param_dict, None) + # plot_filepaths = plot(read_signal_dict, param_dict, 'POD5') webpage_title = "POD5 QC" fast5_html_obj = generate_html.ST_HTML_Generator( [["basic_st", "read_length_bar", "read_length_hist", "base_counts", "basic_info", "base_quality", "read_avg_base_quality", "ont_signal"], webpage_title, param_dict], plot_filepaths, static=False) - fast5_html_obj.generate_st_html(signal_plots=True) + fast5_html_obj.generate_html(signal_plots=True) logging.info("Done. Output files are in %s", param_dict["output_folder"]) else: @@ -598,6 +640,10 @@ def pod5_module(margs): # Add an argument for specifying the read names to extract pod5_parser.add_argument("-r", "--read_ids", type=str, default=None, help="A comma-separated list of read IDs to extract from the file.") + +# Add an argument for specifying the basecalled BAM file +pod5_parser.add_argument("-b", "--basecalls", type=str, default=None, + help="The basecalled BAM file to use for signal extraction.") # Sequencing summary text file input seqtxt_parser = subparsers.add_parser('seqtxt', @@ -620,6 +666,18 @@ def pod5_module(margs): description="For example:\n" "python %(prog)s -i input.bam -o /output_directory/", formatter_class=RawTextHelpFormatter) + +bam_parser.add_argument("--ref", type=str, default="", + help="Reference genome file for the BAM file, used for base modification analysis. Default: None.") + +# Add argument for base modification filtering threshold +bam_parser.add_argument("--modprob", type=float, default=0.5, + help="Base modification filtering threshold. Above/below this value, the base is considered modified/unmodified. Default: 0.5.") + +# Add argument for GTF file required for RNA-seq analysis (TIN, etc.) +bam_parser.add_argument("--gtf", type=str, default="", + help="GTF file required for RNA-seq analysis. Default: None.") + bam_parser.set_defaults(func=bam_module) # RRMS BAM file input (Splits accepted and rejected reads) diff --git a/src/fast5_module.cpp b/src/fast5_module.cpp index 85b4d8b..233fce2 100644 --- a/src/fast5_module.cpp +++ b/src/fast5_module.cpp @@ -416,7 +416,7 @@ static int writeBaseQCDetails(const char *input_file, Output_FAST5 &output_data, // First remove the prefix std::string read_id = read_name.substr(5); - std::cout << "Processing read ID: " << read_id << std::endl; + // std::cout << "Processing read ID: " << read_id << std::endl; //std::cout << "Read: " << read_name << std::endl; // Set up the analysis and basecall group @@ -512,16 +512,13 @@ static int writeSignalQCDetails(const char *input_file, Output_FAST5 &output_dat //std::cout << "Skipping read ID: " << read_id << std::endl; continue; } else { - std::cout << "Processing read ID: " << read_id << std::endl; + // std::cout << "Processing read ID: " << read_id << std::endl; } } // std::cout << "Read: " << read_name << std::endl; - // Get the basecall signals - // std::cout << "Getting basecall signals" << std::endl; + // Append the basecall signals to the output structure Base_Signals basecall_obj = getReadBaseSignalData(f5, read_name, false); - - //std::cout << "Adding basecall signals" << std::endl; output_data.addReadBaseSignals(basecall_obj); } } diff --git a/src/generate_html.py b/src/generate_html.py index d50955e..7a80511 100644 --- a/src/generate_html.py +++ b/src/generate_html.py @@ -13,6 +13,7 @@ def __init__(self, para_list, plot_filepaths, static=True): self.static = static # Static vs. dynamic webpage boolean self.plot_filepaths = plot_filepaths self.prg_name = self.input_para["prg_name"] # Program name + self.html_writer = None if len(self.input_para["input_files"]) > 1: self.more_input_files = True @@ -20,8 +21,9 @@ def __init__(self, para_list, plot_filepaths, static=True): self.more_input_files = False def generate_header(self): + """Format the header of the HTML file with the title and CSS.""" html_filepath = self.input_para["output_folder"] + '/' + self.input_para["out_prefix"] + ".html" - self.html_writer = open(html_filepath, 'w') + self.html_writer = open(html_filepath, 'w', encoding='utf-8') self.html_writer.write("") self.html_writer.write("") self.html_writer.write("") @@ -225,18 +227,17 @@ def generate_header(self): <div id="header_filename"> <script> document.write(new Date().toLocaleDateString()); </script> '''.format(self.prg_name)) - # for _af in self.input_para["input_files"]: - # self.html_writer.write( "<br/>"+_af); - # self.html_writer.write( "<br/>"+ self.input_para["input_files"][0] ) - self.html_writer.write(''' - </div> - </div>''') + self.html_writer.write('''</div></div>''') def generate_left(self): + """Generate the left section of the HTML file with the links to the + right section.""" + # Add the summary section of links self.html_writer.write('<div class="summary">'); self.html_writer.write('<h2>Summary</h2>') self.html_writer.write('<ul>') + # Add links to the right sections key_index = 0 for plot_key in self.image_key_list: self.html_writer.write('<li>') @@ -246,15 +247,17 @@ def generate_left(self): key_index += 1 self.html_writer.write('</li>') + # Add the input files section link self.html_writer.write('<li>') self.html_writer.write('<a href="#lrst' + str(key_index) + '">Input File List</a>') key_index += 1 self.html_writer.write('</li>') - + self.html_writer.write("</ul>") self.html_writer.write('</div>') def generate_right(self): + """Generate the right section of the HTML file with the plots and tables.""" self.html_writer.write('<div class="main">') key_index = 0 for plot_key in self.image_key_list: @@ -262,16 +265,19 @@ def generate_right(self): self.html_writer.write( '<h2 id="lrst' + str(key_index) + '">' + self.plot_filepaths[plot_key]['description'] + '</h2><p>') - # Add the plot or the HTML summary table - if plot_key == "basic_st": - self.html_writer.write(self.plot_filepaths["basic_st"]['detail']) + # Add the figures + if plot_key == "basic_st" or plot_key == "base_mods": + # Add the HTML tables + self.html_writer.write(self.plot_filepaths[plot_key]['detail']) + else: + # Add the dynamic plots try: dynamic_plot = self.plot_filepaths[plot_key]['dynamic'] self.html_writer.write(dynamic_plot) except KeyError: - logging.error("Missing dynamic plot for " + plot_key) + logging.error("Missing dynamic plot for %s", plot_key) self.html_writer.write('</div>') @@ -287,13 +293,12 @@ def generate_right(self): self.html_writer.write('</div>') - # Generate links in the left panel def generate_left_signal_data(self, read_names): + """Generate the left section of the HTML file with the links to the right section.""" self.html_writer.write('<div class="summary">'); self.html_writer.write('<h2>Summary</h2>') self.html_writer.write('<ul>') - # Add the summary table section link url_index = 0 self.html_writer.write('<li>') @@ -363,7 +368,7 @@ def generate_end(self): self.html_writer.close() # Main function for generating the HTML. - def generate_st_html(self, signal_plots=False): + def generate_html(self, signal_plots=False): if signal_plots: self.generate_header() # Get the signal plots diff --git a/src/hts_reader.cpp b/src/hts_reader.cpp index 180579a..6a29067 100644 --- a/src/hts_reader.cpp +++ b/src/hts_reader.cpp @@ -16,6 +16,12 @@ Class for reading a set number of records from a BAM file. Used for multi-thread #include "hts_reader.h" #include "utils.h" +void HTSReader::addModificationToQueryMap(std::map<int32_t, std::tuple<char, char, double, int>> &base_modifications, int32_t pos, char mod_type, char canonical_base, double likelihood, int strand) +{ + // Add the modification type to the map + base_modifications[pos] = std::make_tuple(mod_type, canonical_base, likelihood, strand); +} + // HTSReader constructor HTSReader::HTSReader(const std::string & bam_file_name){ this->bam_file = hts_open(bam_file_name.c_str(), "r"); @@ -90,15 +96,13 @@ int HTSReader::readNextRecords(int batch_size, Output_BAM & output_data, std::mu // Access the base quality histogram from the output_data object uint64_t *base_quality_distribution = output_data.seq_quality_info.base_quality_distribution; - // Loop through each alignment record in the BAM file // Do QC on each record and store the results in the output_data object - // bool nm_tag_present = false; // Flag to determine if the NM tag is present (for mismatch counting) - bool mod_tag_present = false; // Flag to determine if the base modification tags (MM, ML) are present + bool first_pod5_tag = false; while ((record_count < batch_size) && (exit_code >= 0)) { // Create a record object bam1_t* record = bam_init1(); - // read the next record in a thread-safe manner + // Read the next record in a thread-safe manner read_mutex.lock(); exit_code = sam_read1(this->bam_file, this->header, record); read_mutex.unlock(); @@ -109,76 +113,187 @@ int HTSReader::readNextRecords(int batch_size, Output_BAM & output_data, std::mu break; // error or EOF } + + // Get the read (query) name + std::string query_name = bam_get_qname(record); + + // Determine if this read should be skipped + if (read_ids_present){ + + + // Determine if this read should be skipped + if (read_ids.find(query_name) == read_ids.end()){ + // std::cout << "Skipping read " << query_name << std::endl; + continue; // Skip this read + } + } + + // For POD5 files, corresponding BAM files will have tags for + // indexing signal data for each base (ts, ns, mv). Find the + // tags and store in the output data object + uint8_t *ts_tag = bam_aux_get(record, "ts"); + uint8_t *ns_tag = bam_aux_get(record, "ns"); + uint8_t *mv_tag = bam_aux_get(record, "mv"); + + // Get POD5 signal tag values if they exist + if (mv_tag != NULL && ts_tag != NULL && ns_tag != NULL) { + // Set the atomic flag and print a message if the POD5 tags are + // present + if (!this->has_pod5_tags.test_and_set()) { + printMessage("POD5 tags found (ts, ns, mv)"); + first_pod5_tag = true; + } + + // Get the ts and ns tags + int32_t ts = bam_aux2i(ts_tag); + int32_t ns = bam_aux2i(ns_tag); + // if (first_pod5_tag) { + // printMessage("ts: " + std::to_string(ts) + ", ns: " + std::to_string(ns)); + // } + + // Get the move table (start at 1 to skip the tag type) + int max_print = 15; + int32_t length = bam_auxB_len(mv_tag); + std::vector<int32_t> move_table(length); + std::vector<std::vector<int>> sequence_move_table; // Store the sequence move table with indices + // if (first_pod5_tag) { + // printMessage("Move table length: " + std::to_string(length)); + // } + + int base_signal_length = 0; + + // Iterate over the move table values + int prev_value = 0; + int current_index = ts; + std::vector<int> signal_index_vector; + int move_value = 0; + for (int32_t i = 1; i < length; i++) { + move_value = bam_auxB2i(mv_tag, i); + if (move_value == 1) { + signal_index_vector.push_back(current_index); + } + + current_index++; + } + // Create a tuple and add the read's signal data to the output data + std::string seq_str = ""; + for (int i = 0; i < record->core.l_qseq; i++) { + seq_str += seq_nt16_str[bam_seqi(bam_get_seq(record), i)]; + } + + // Throw an error if the query name is empty + if (query_name.empty()) { + std::cerr << "Error: Query name is empty" << std::endl; + exit_code = 1; + break; + } + output_data.addReadMoveTable(query_name, seq_str, signal_index_vector, ts, ns); + + // if (first_pod5_tag) { + // printMessage("Signal vector length: " + // + std::to_string(signal_index_vector.size()) + ", Sequence string length: " + // + std::to_string(seq_str.length())); + // // printMessage("Base signal length: " + std::to_string(base_signal_length) + ", Sequence string length: " + std::to_string(seq_str.length())); + + // // printMessage("Base vector length: " + std::to_string(sequence_move_table.size())); + // // printMessage("Test count: " + std::to_string(test_count)); + // // printMessage("Sequence string length: " + std::to_string(seq_str.length())); + // } + } + // Follow here to get base modification tags: // https://github.com/samtools/htslib/blob/11205a9ba5e4fc39cc8bb9844d73db2a63fb8119/sam_mods.c // https://github.com/samtools/htslib/blob/11205a9ba5e4fc39cc8bb9844d73db2a63fb8119/htslib/sam.h#L2274 hts_base_mod_state *state = hts_base_mod_state_alloc(); - if (bam_parse_basemod(record, state) >= 0) { - printMessage("Base modification tags found"); - // std::cout << "Base modification tags found" << std::endl; - mod_tag_present = true; + std::map<int32_t, std::tuple<char, char, double, int>> query_base_modifications; + + // Parse the base modification tags if a primary alignment + read_mutex.lock(); + int ret = bam_parse_basemod(record, state); + read_mutex.unlock(); + if (ret >= 0 && !(record->core.flag & BAM_FSECONDARY) && !(record->core.flag & BAM_FSUPPLEMENTARY) && !(record->core.flag & BAM_FUNMAP)) { + + // Get the chromosome if alignments are present + bool alignments_present = true; + std::string chr; + if (record->core.tid < 0) { + alignments_present = false; + } else { + chr = this->header->target_name[record->core.tid]; + } + + // Get the strand from the alignment flag (hts_base_mod uses 0 for positive and 1 for negative, + // but it always yields 0...) + int strand = (record->core.flag & BAM_FREVERSE) ? 1 : 0; // Iterate over the state object to get the base modification tags // using bam_next_basemod hts_base_mod mods[10]; int n = 0; int pos = 0; + std::vector<int> query_pos; while ((n=bam_next_basemod(record, state, mods, 10, &pos)) > 0) { for (int i = 0; i < n; i++) { - // Struct definition: https://github.com/samtools/htslib/blob/11205a9ba5e4fc39cc8bb9844d73db2a63fb8119/htslib/sam.h#L2226 - printMessage("Found base modification at position " + std::to_string(pos)); - printMessage("Modification type: " + std::string(1, mods[i].modified_base)); - printMessage("Canonical base: " + std::string(1, mods[i].canonical_base)); - printMessage("Likelihood: " + std::to_string(mods[i].qual / 256.0)); - printMessage("Strand: " + std::to_string(mods[i].strand)); - - // - // std::cout << "Base modification at position " << pos << std::endl; - // std::cout << "Base modification type: " << mods[i].modified_base << std::endl; - // std::cout << "Base modification likelihood: " << mods[i].qual / 256.0 << std::endl; - // std::cout << "Base modification strand: " << mods[i].strand << std::endl; + // Update the prediction count + output_data.modified_prediction_count++; + + // Note: The modified base value can be a positive char (e.g. 'm', + // 'h') (DNA Mods DB) or negative integer (ChEBI ID): + // https://github.com/samtools/hts-specs/issues/741 + // DNA Mods: https://dnamod.hoffmanlab.org/ + // ChEBI: https://www.ebi.ac.uk/chebi/searchId.do?chebiId=CHEBI:21839 + // Header line: + // https://github.com/samtools/htslib/blob/11205a9ba5e4fc39cc8bb9844d73db2a63fb8119/htslib/sam.h#L2215 + + // TODO: Look into htslib error with missing strand information + + // Determine the probability of the modification (-1 if + // unknown) + double probability = -1; + if (mods[i].qual != -1) { + probability = mods[i].qual / 256.0; + } + + // Add the modification to the query base modifications map + this->addModificationToQueryMap(query_base_modifications, pos, mods[i].modified_base, mods[i].canonical_base, probability, strand); + query_pos.push_back(pos); } } - // Iterating by position - // hts_base_mod mods[10]; - // int n = bam_mods_at_next_pos(record, state, mods, 10); - // for (int i = 0; i < n; i++) { - // std::cout << "Base modification at position " << mods[i].pos << std::endl; - // std::cout << "Base modification type: " << mods[i].type << std::endl; - // std::cout << "Base modification likelihood: " << mods[i].likelihood << std::endl; - // } - - - // // Get the ML tag (base modification likelihoods) from the state - // // object (https://github.com/samtools/htslib/blob/11205a9ba5e4fc39cc8bb9844d73db2a63fb8119/sam_mods.c#L176) - // if (state->ml) { - // std::cout << "ML tag found" << std::endl; - // // std::cout << "ML tag: " << state->ml << std::endl; - // // std::cout << "ML tag length: " << state->ml_len << std::endl; - // // std::cout << "ML tag type: " << state->ml_type << std::endl; - // // std::cout << "ML tag num: " << state->ml_num << std::endl; - // // std::cout << "ML tag num length: " << state->ml_num_len << std::endl; - // // std::cout << "ML tag num type: " << state->ml_num_type << - // // std::endl; - // } - } else { - std::cout << "No base modification tags found" << std::endl; - } - - // Determine if this read should be skipped - if (read_ids_present){ - // Get the alignment's query name (the read name) - std::string query_name = bam_get_qname(record); - //std::cout << "Query name: " << query_name << std::endl; + // Set the atomic flag and print a message if the base modification + // tags are present + if (query_pos.size() > 0 && !this->has_mm_ml_tags.test_and_set()) { + printMessage("Base modification data found (MM, ML tags)"); + } - // Determine if this read should be skipped - if (read_ids.find(query_name) == read_ids.end()){ - // std::cout << "Skipping read " << query_name << std::endl; - continue; // Skip this read + // If alignments are present, get the reference positions of the query positions + if (alignments_present && query_pos.size() > 0) { + // Get the query to reference position mapping + std::map<int, int> query_to_ref_map = this->getQueryToRefMap(record); + std::vector<int> ref_pos(query_pos.size(), -1); + + // Loop through the query and reference positions and add the + // reference positions to the output data + for (size_t i = 0; i < query_pos.size(); i++) { + // Get the reference position from the query to reference + // map + if (query_to_ref_map.find(query_pos[i]) != query_to_ref_map.end()) { + ref_pos[i] = query_to_ref_map[query_pos[i]]; + + // Add the modification to the output data + char mod_type = std::get<0>(query_base_modifications[query_pos[i]]); + char canonical_base = std::get<1>(query_base_modifications[query_pos[i]]); + double likelihood = std::get<2>(query_base_modifications[query_pos[i]]); + int strand = std::get<3>(query_base_modifications[query_pos[i]]); + output_data.add_modification(chr, ref_pos[i], mod_type, canonical_base, likelihood, strand); + } + } } } + // Deallocate the state object + hts_base_mod_state_free(state); + // Determine if this is an unmapped read if (record->core.flag & BAM_FUNMAP) { Basic_Seq_Statistics *basic_qc = &output_data.unmapped_long_read_info; @@ -228,7 +343,18 @@ int HTSReader::readNextRecords(int batch_size, Output_BAM & output_data, std::mu uint8_t *nmTag = bam_aux_get(record, "NM"); if (nmTag != NULL) { num_mismatches = (uint64_t) bam_aux2i(nmTag); - // nm_tag_present = true; + + // Set the atomic flag and print a message if the NM tag is + // present + if (!this->has_nm_tag.test_and_set()) { + printMessage("NM tag found, used NM tag for mismatch count"); + } + } else { + // Set the atomic flag and print a message if the NM tag is + // not present + if (!this->has_nm_tag.test_and_set()) { + printMessage("No NM tag found, using CIGAR for mismatch count"); + } } output_data.num_mismatched_bases += num_mismatches; } @@ -257,18 +383,38 @@ int HTSReader::readNextRecords(int batch_size, Output_BAM & output_data, std::mu } else if (!(record->core.flag & BAM_FSECONDARY || record->core.flag & BAM_FSUPPLEMENTARY)) { output_data.num_primary_alignment++; // Update the number of primary alignments - // Loop through the cigar string and count the number of clipped bases + // Loop through the cigar string, count the number of clipped + // bases, and also get the reference position of the read if + // there is a modification tag uint32_t *cigar = bam_get_cigar(record); + int32_t ref_pos = record->core.pos; + int query_pos = 0; for (uint32_t i = 0; i < record->core.n_cigar; i++) { int cigar_op = bam_cigar_op(cigar[i]); uint64_t cigar_len = (uint64_t)bam_cigar_oplen(cigar[i]); switch (cigar_op) { case BAM_CSOFT_CLIP: output_data.num_clip_bases += cigar_len; + query_pos += cigar_len; // Consumes query bases break; case BAM_CHARD_CLIP: output_data.num_clip_bases += cigar_len; break; + case BAM_CMATCH: + case BAM_CEQUAL: + case BAM_CDIFF: + ref_pos += cigar_len; + query_pos += cigar_len; + break; + case BAM_CINS: + query_pos += cigar_len; + break; + case BAM_CDEL: + ref_pos += cigar_len; + break; + case BAM_CREF_SKIP: + ref_pos += cigar_len; + break; default: break; } @@ -283,18 +429,6 @@ int HTSReader::readNextRecords(int batch_size, Output_BAM & output_data, std::mu // Update the GC content histogram basic_qc->read_gc_content_count.push_back(percent_gc); - // Determine if the base modification tags are present - uint8_t *mmTag = bam_aux_get(record, "MM:Z"); - uint8_t *mlTag = bam_aux_get(record, "ML:B:C"); - // uint8_t *mmTag = bam_aux_get(record, "mm"); - // uint8_t *mlTag = bam_aux_get(record, "ml"); - // uint8_t *mmTag = bam_aux_get(record, "MM"); - // uint8_t *mlTag = bam_aux_get(record, "ML"); - - if (mmTag != NULL || mlTag != NULL) { - mod_tag_present = true; - } - } else { std::cerr << "Error: Unknown alignment type" << std::endl; std::cerr << "Flag: " << record->core.flag << std::endl; @@ -307,22 +441,6 @@ int HTSReader::readNextRecords(int batch_size, Output_BAM & output_data, std::mu record_count++; } - // Print if the NM tag was not present - // if (nm_tag_present) - // { - // std::cout << "NM tag found, used NM tag for mismatch count" << std::endl; - // } else { - // std::cout << "No NM tag found, used CIGAR for mismatch count" << std::endl; - // } - - // Print if the base modification tags were present - if (mod_tag_present) - { - std::cout << "Base modification tags found" << std::endl; - } else { - std::cout << "[test2] No base modification tags found" << std::endl; - } - return exit_code; } @@ -348,3 +466,50 @@ int64_t HTSReader::getNumRecords(const std::string & bam_filename){ return num_reads; } + +// Get the mapping of query positions to reference positions for a given alignment record +std::map<int, int> HTSReader::getQueryToRefMap(bam1_t *record) +{ + std::map<int, int> query_to_ref_map; + + // Initialize the starting reference and query positions + int32_t current_ref_pos = record->core.pos; // Get the reference position + int current_query_pos = 0; + uint32_t *cigar = bam_get_cigar(record); + + // Iterate over the CIGAR operations + int cigar_len = record->core.n_cigar; + for (int i = 0; i < cigar_len; i++) { + int cigar_op = bam_cigar_op(cigar[i]); // Get the CIGAR operation + int op_len = bam_cigar_oplen(cigar[i]); // Get the CIGAR operation length + + switch (cigar_op) { + case BAM_CDIFF: + current_ref_pos += op_len; + current_query_pos += op_len; + break; + case BAM_CMATCH: + case BAM_CEQUAL: + for (int j = 0; j < op_len; j++) { + query_to_ref_map[current_query_pos] = current_ref_pos + 1; // Use 1-indexed positions + current_ref_pos++; + current_query_pos++; + // query_to_ref_map[current_query_pos] = current_ref_pos + 1; // Use 1-indexed positions + } + break; + case BAM_CINS: + case BAM_CSOFT_CLIP: + current_query_pos += op_len; + break; + case BAM_CDEL: + case BAM_CREF_SKIP: + current_ref_pos += op_len; + break; + default: + // Handle unexpected CIGAR operations if needed + break; + } + } + + return query_to_ref_map; +} diff --git a/src/input_parameters.cpp b/src/input_parameters.cpp index 055ca76..fedae49 100644 --- a/src/input_parameters.cpp +++ b/src/input_parameters.cpp @@ -4,23 +4,25 @@ Input_Para::Input_Para(){ // Set default parameters - num_input_files = 0; - threads = 1; - rdm_seed = 1; - downsample_percentage = 100; - other_flags = 0; - user_defined_fastq_base_qual_offset = -1; - rrms_csv = ""; + this->num_input_files = 0; + this->threads = 1; + this->rdm_seed = 1; + this->downsample_percentage = 100; + this->other_flags = 0; + this->user_defined_fastq_base_qual_offset = -1; + this->rrms_csv = ""; + this->ref_genome = ""; + this->base_mod_threshold = 0.5; } Input_Para::~Input_Para(){ } -std::string Input_Para::add_input_file(const std::string& _ip_file){ - if ( num_input_files < MAX_INPUT_FILES){ - input_files[ num_input_files ] = _ip_file; - num_input_files++; +std::string Input_Para::add_input_file(const std::string& input_filepath){ + if (this->num_input_files < MAX_INPUT_FILES){ + this->input_files[ this->num_input_files ] = input_filepath; + this->num_input_files++; return ""; }else{ return "Only "+std::to_string(MAX_INPUT_FILES)+" input files are supported!!"; diff --git a/src/lrst.i b/src/lrst.i index 27ea05c..9def020 100644 --- a/src/lrst.i +++ b/src/lrst.i @@ -36,6 +36,70 @@ lrst.i: SWIG module defining the Python wrapper for our C++ modules $result = list; } +// Map std::map<int32_t, std::map<char, std::tuple<char, double>>> to Python +// dictionary +// %typemap(out) std::map<int32_t, std::map<char, std::tuple<char, double>>> { +// PyObject *dict = PyDict_New(); +// for (auto const &it : $1) { +// PyObject *inner_dict = PyDict_New(); +// for (auto const &inner_it : it.second) { +// PyObject *tuple = PyTuple_Pack(2, +// PyUnicode_FromStringAndSize(&std::get<0>(inner_it.second), 1), +// PyFloat_FromDouble(std::get<1>(inner_it.second))); +// PyDict_SetItem(inner_dict, +// PyUnicode_FromStringAndSize(&inner_it.first, 1), +// tuple); +// } +// PyDict_SetItem(dict, PyLong_FromLong(it.first), inner_dict); +// } +// $result = dict; +// } + +// Map std::map<int32_t, std::tuple<char, char, double, int, bool>> to Python +// dictionary +// %typemap(out) std::map<int32_t, std::tuple<char, char, double, int, bool>> { +// PyObject *dict = PyDict_New(); +// for (auto const &it : $1) { +// PyObject *tuple = PyTuple_Pack(5, +// PyUnicode_FromStringAndSize(&std::get<0>(it.second), 1), +// PyUnicode_FromStringAndSize(&std::get<1>(it.second), 1), +// PyFloat_FromDouble(std::get<2>(it.second)), +// PyLong_FromLong(std::get<3>(it.second)), +// PyBool_FromLong(std::get<4>(it.second))); +// PyDict_SetItem(dict, PyLong_FromLong(it.first), tuple); +// } +// $result = dict; +// } + +// Map std::map<std::string, std::map<int32_t, std::tuple<char, char, double, +// int, bool>>> to Python dictionary +%typemap(out) std::map<std::string, std::map<int32_t, std::tuple<char, char, double, int, bool>>> { + PyObject *dict = PyDict_New(); + for (auto const &it : $1) { + PyObject *inner_dict = PyDict_New(); + for (auto const &inner_it : it.second) { + PyObject *tuple = PyTuple_Pack(5, + PyUnicode_FromStringAndSize(&std::get<0>(inner_it.second), 1), + PyUnicode_FromStringAndSize(&std::get<1>(inner_it.second), 1), + PyFloat_FromDouble(std::get<2>(inner_it.second)), + PyLong_FromLong(std::get<3>(inner_it.second)), + PyBool_FromLong(std::get<4>(inner_it.second))); + PyDict_SetItem(inner_dict, PyLong_FromLong(inner_it.first), tuple); + } + PyDict_SetItem(dict, PyUnicode_FromString(it.first.c_str()), inner_dict); + } + $result = dict; +} + +// Map std::map<char, int> to Python dictionary +%typemap(out) std::map<char, int> { + PyObject *dict = PyDict_New(); + for (auto const &it : $1) { + PyDict_SetItem(dict, PyUnicode_FromStringAndSize(&it.first, 1), PyLong_FromLong(it.second)); + } + $result = dict; +} + %include <std_string.i> %include <stdint.i> %include <std_vector.i> diff --git a/src/output_data.cpp b/src/output_data.cpp index d0ef854..6d0ddb3 100644 --- a/src/output_data.cpp +++ b/src/output_data.cpp @@ -5,6 +5,7 @@ #include <sstream> #include "output_data.h" +#include "utils.h" #include "basic_statistics.h" // Base class for storing error output. @@ -261,7 +262,91 @@ Output_BAM::Output_BAM(){ Output_BAM::~Output_BAM(){ } -void Output_BAM::add(Output_BAM& output_data){ +void Output_BAM::add_modification(std::string chr, int32_t ref_pos, char mod_type, char canonical_base, double likelihood, int strand) +{ + // Add the modification to the map of reference positions + try { + this->base_modifications.at(chr); + } catch (const std::out_of_range& oor) { + this->base_modifications[chr] = std::map<int32_t, Base_Modification>(); + } + + try { + this->base_modifications[chr].at(ref_pos); + + // If the reference position is already in the map, use the modification + // type with the highest likelihood + double previous_likelihood = std::get<2>(this->base_modifications[chr][ref_pos]); + if (likelihood > previous_likelihood){ + this->base_modifications[chr][ref_pos] = std::make_tuple(mod_type, canonical_base, likelihood, strand, false); + } + } catch (const std::out_of_range& oor) { + + // If the reference position is not in the map, add the modification + this->base_modifications[chr][ref_pos] = std::make_tuple(mod_type, canonical_base, likelihood, strand, false); + } + + // Add the modification type to the map of modification types + try { + this->modification_type_counts.at(mod_type); + this->modification_type_counts[mod_type] += 1; + } catch (const std::out_of_range& oor) { + this->modification_type_counts[mod_type] = 1; + } +} + +int Output_BAM::getReadCount() +{ + return this->read_move_table.size(); +} + +void Output_BAM::addReadMoveTable(std::string read_name, std::string sequence_data_str, std::vector<int> signal_index, int start, int end) +{ + Base_Move_Table values(sequence_data_str, signal_index, start, end); + this->read_move_table[read_name] = values; +} + +std::vector<int> Output_BAM::getReadMoveTable(std::string read_id) +{ + try { + this->read_move_table.at(read_id); + } catch (const std::out_of_range& oor) { + std::cerr << "Error: Read name " << read_id << " is not in the move table." << std::endl; + } + return this->read_move_table[read_id].getBaseSignalIndex(); +} + +// Get the read's sequence string +std::string Output_BAM::getReadSequence(std::string read_id) +{ + try { + this->read_move_table.at(read_id); + } catch (const std::out_of_range& oor) { + std::cerr << "Error: Read name " << read_id << " is not in the move table." << std::endl; + } + + Base_Move_Table signal_data = this->read_move_table[read_id]; + std::string sequence_str(signal_data.getSequenceString()); + return sequence_str; +} + +int Output_BAM::getReadSequenceStart(std::string read_id) +{ + return this->read_move_table[read_id].getSequenceStart(); +} + +int Output_BAM::getReadSequenceEnd(std::string read_id) +{ + return this->read_move_table[read_id].getSequenceEnd(); +} + +std::map<std::string, std::map<int32_t, Base_Modification>> Output_BAM::get_modifications() +{ + return this->base_modifications; +} + +void Output_BAM::add(Output_BAM &output_data) +{ this->num_primary_alignment += output_data.num_primary_alignment; this->num_secondary_alignment += output_data.num_secondary_alignment; this->num_supplementary_alignment += output_data.num_supplementary_alignment; @@ -278,13 +363,7 @@ void Output_BAM::add(Output_BAM& output_data){ this->forward_alignment += output_data.forward_alignment; this->reverse_alignment += output_data.reverse_alignment; -// // Resize the base quality vector if it is empty -// if ( this->seq_quality_info.base_quality_distribution.empty() ){ -// this->seq_quality_info.base_quality_distribution.resize( MAX_READ_QUALITY ); -// } - // Update the base quality vector if it is not empty -// if ( !output_data.seq_quality_info.base_quality_distribution.empty() ){ for (int i=0; i<MAX_READ_QUALITY; i++){ this->seq_quality_info.base_quality_distribution[i] += output_data.seq_quality_info.base_quality_distribution[i]; } @@ -300,14 +379,40 @@ void Output_BAM::add(Output_BAM& output_data){ this->long_read_info.add(output_data.mapped_long_read_info); this->long_read_info.add(output_data.unmapped_long_read_info); + + // Update the base modification information + for (auto const &it : output_data.base_modifications) { + std::string chr = it.first; + for (auto const &it2 : it.second) { + int32_t ref_pos = it2.first; + char mod_type = std::get<0>(it2.second); + char canonical_base = std::get<1>(it2.second); + double likelihood = std::get<2>(it2.second); + int strand = std::get<3>(it2.second); + this->add_modification(chr, ref_pos, mod_type, canonical_base, likelihood, strand); + } + } + + // Update base modification counts + this->modified_prediction_count += output_data.modified_prediction_count; + + // Update the map + for ( auto it = output_data.read_move_table.begin(); it != output_data.read_move_table.end(); ++it ){ + std::string read_id = it->first; + std::vector<int> signal_index = it->second.getBaseSignalIndex(); + std::string sequence_data_str = it->second.getSequenceString(); + int start = it->second.getSequenceStart(); + int end = it->second.getSequenceEnd(); + this->addReadMoveTable(read_id, sequence_data_str, signal_index, start, end); + } } void Output_BAM::global_sum(){ + // Calculate the global sums for the basic statistics mapped_long_read_info.global_sum(); unmapped_long_read_info.global_sum(); mapped_seq_quality_info.global_sum(); unmapped_seq_quality_info.global_sum(); - long_read_info.global_sum(); seq_quality_info.global_sum(); @@ -415,10 +520,10 @@ void Output_SeqTxt::global_sum(){ // Base class for storing a read's base signal data Base_Signals::Base_Signals(std::string read_name, std::string sequence_data_str, std::vector<std::vector<int>> basecall_signals) { - this->read_name = read_name; // Update the read name - this->sequence_data_str = sequence_data_str; // Update the sequence string - this->basecall_signals = basecall_signals; // Update values - this->base_count = basecall_signals.size(); // Update read length + this->read_name = read_name; + this->sequence_data_str = sequence_data_str; + this->basecall_signals = basecall_signals; + this->base_count = basecall_signals.size(); } std::vector<std::vector<int>> Base_Signals::getDataVector() { @@ -637,3 +742,35 @@ std::vector<double> Output_FAST5::getNthReadKurtosis(int read_index){ return output; } + +std::string Base_Move_Table::getSequenceString() +{ + return this->sequence_data_str; +} + +std::vector<int> Base_Move_Table::getBaseSignalIndex() +{ + return this->base_signal_index; +} + +int Base_Move_Table::getSequenceStart() +{ + return this->sequence_start; +} + +int Base_Move_Table::getSequenceEnd() +{ + return this->sequence_end; +} + +Base_Move_Table::Base_Move_Table(std::string sequence_data_str, std::vector<int> base_signal_index, int start, int end) +{ + this->sequence_data_str = sequence_data_str; + this->base_signal_index = base_signal_index; + this->sequence_start = start; + this->sequence_end = end; +} + +Base_Move_Table::Base_Move_Table() +{ +} diff --git a/src/plot_utils.py b/src/plot_utils.py index fd7a485..f29d6e8 100644 --- a/src/plot_utils.py +++ b/src/plot_utils.py @@ -333,9 +333,37 @@ def read_avg_base_quality(data, font_size): return fig.to_html(full_html=False, default_height=500, default_width=700) + +def plot_base_modifications(base_modifications): + """Plot the base modifications per location.""" + # Get the modification types + modification_types = list(base_modifications.keys()) + + # Create the figure + fig = go.Figure() + + # Add a trace for each modification type + for mod_type in modification_types: + # Get the modification data + mod_data = base_modifications[mod_type] + + # Create the trace + trace = go.Scatter(x=mod_data['positions'], y=mod_data['counts'], mode='markers', name=mod_type) + + # Add the trace to the figure + fig.add_trace(trace) + + # Update the layout + fig.update_layout(title='Base Modifications', xaxis_title='Position', yaxis_title='Counts', showlegend=True, font=dict(size=PLOT_FONT_SIZE)) + + # Generate the HTML + html_obj = fig.to_html(full_html=False, default_height=500, default_width=700) + + return html_obj + + # Main plot function def plot(output_data, para_dict, file_type): - out_path = para_dict["output_folder"] plot_filepaths = getDefaultPlotFilenames() # Get the font size for plotly plots @@ -344,6 +372,17 @@ def plot(output_data, para_dict, file_type): # Create the summary table create_summary_table(output_data, plot_filepaths, file_type) + # Create the modified base table if available + if file_type == 'BAM' and output_data.modified_base_count > 0: + base_modification_threshold = para_dict["modprob"] + create_modified_base_table(output_data, plot_filepaths, base_modification_threshold) + + # Check if the modified base table is available + if 'base_mods' in plot_filepaths: + logging.info("SUCCESS: Modified base table created") + else: + logging.warning("WARNING: Modified base table not created") + # Generate plots plot_filepaths['base_counts']['dynamic'] = plot_base_counts(output_data, file_type) plot_filepaths['basic_info']['dynamic'] = plot_basic_info(output_data, file_type) @@ -373,6 +412,7 @@ def plot(output_data, para_dict, file_type): plot_filepaths['read_avg_base_quality']['dynamic'] = read_quality_dynamic if file_type == 'BAM': + # Plot read alignment QC plot_filepaths['read_alignments_bar']['dynamic'] = plot_alignment_numbers(output_data) plot_filepaths['base_alignments_bar']['dynamic'] = plot_errors(output_data) @@ -381,22 +421,21 @@ def plot(output_data, para_dict, file_type): return plot_filepaths -def plot_pod5(output_dict, para_dict): +def plot_pod5(pod5_output, para_dict, bam_output=None): """Plot the ONT POD5 signal data for a random sample of reads.""" out_path = para_dict["output_folder"] plot_filepaths = getDefaultPlotFilenames() - # Get the font size for plotly plots font_size = para_dict["fontsize"] # Create the summary table - create_pod5_table(output_dict, plot_filepaths) + create_pod5_table(pod5_output, plot_filepaths) # Generate the signal plots marker_size = para_dict["markersize"] read_count_max = para_dict["read_count"] - read_count = len(output_dict.keys()) + read_count = len(pod5_output.keys()) logging.info("Plotting signal data for {} reads".format(read_count)) # Randomly sample a small set of reads if it is a large dataset @@ -409,6 +448,7 @@ def plot_pod5(output_dict, para_dict): else: logging.info("Plotting signal data for all {} reads".format(read_count)) + # Plot the reads output_html_plots = {} for read_index in read_indices: @@ -416,15 +456,15 @@ def plot_pod5(output_dict, para_dict): fig = go.Figure() # Get the read data - nth_read_name = list(output_dict.keys())[read_index] - nth_read_data = output_dict[nth_read_name]['signal'] + nth_read_name = list(pod5_output.keys())[read_index] + nth_read_data = pod5_output[nth_read_name]['signal'] signal_length = len(nth_read_data) logging.info("Signal data count for read {}: {}".format(nth_read_name, signal_length)) - nth_read_mean = output_dict[nth_read_name]['mean'] - nth_read_std = output_dict[nth_read_name]['std'] - nth_read_median = output_dict[nth_read_name]['median'] - nth_read_skewness = output_dict[nth_read_name]['skewness'] - nth_read_kurtosis = output_dict[nth_read_name]['kurtosis'] + nth_read_mean = pod5_output[nth_read_name]['mean'] + nth_read_std = pod5_output[nth_read_name]['std'] + nth_read_median = pod5_output[nth_read_name]['median'] + nth_read_skewness = pod5_output[nth_read_name]['skewness'] + nth_read_kurtosis = pod5_output[nth_read_name]['kurtosis'] # Set up the output CSV csv_qc_filepath = os.path.join(out_path, nth_read_name + '_QC.csv') @@ -432,7 +472,54 @@ def plot_pod5(output_dict, para_dict): qc_writer = csv.writer(qc_file) qc_writer.writerow(["Raw_Signal", "Length", "Mean", "Median", "StdDev", "PearsonSkewnessCoeff", "Kurtosis"]) - # Loop through the data + # Update CSV + raw_row = [nth_read_data, signal_length, nth_read_mean, nth_read_median, nth_read_std, nth_read_skewness, nth_read_kurtosis] + qc_writer.writerow(raw_row) + + # Close CSV + qc_file.close() + + # Plot the base sequence if available + if bam_output: + move_table = bam_output.getReadMoveTable(nth_read_name) + read_sequence = bam_output.getReadSequence(nth_read_name) + start_index = bam_output.getReadSequenceStart(nth_read_name) + end_index = bam_output.getReadSequenceEnd(nth_read_name) + + # Print the first couple of indices from the table. + # Each index in the move table represents a k-mer move. Thus, for + # each base, the signal is between two indices in the move table, starting + # from the first index. + logging.info("Move table for read {}: {}".format(nth_read_name, move_table[:5])) + logging.info("Move table range: {}-{}".format(min(move_table), max(move_table))) + logging.info("Read sequence for read {}: {}".format(nth_read_name, read_sequence[:5])) + logging.info("Read sequence length for read {}: {}".format(nth_read_name, len(read_sequence))) + logging.info("Signal data length for read {}: {}".format(nth_read_name, len(move_table))) + logging.info("Signal interval for read {}: {}-{}".format(nth_read_name, start_index, end_index)) + + # Filter the signal data. Use the last index of the move table + 20 + # as the end index, since the signal data can be much longer than the + # read sequence. + end_index = max(move_table) + 20 + nth_read_data = nth_read_data[start_index:end_index] + signal_length = len(nth_read_data) + + # Set up the X tick values + base_tick_values = move_table + + # Set up the X tick labels + x_tick_labels = list(read_sequence) + + # Update the plot style + fig.update_xaxes(title="Base", + tickangle=0, + tickmode='array', + tickvals=base_tick_values, + ticktext=x_tick_labels) + else: + fig.update_xaxes(title="Index") + + # Plot the signal data x = np.arange(signal_length) fig.add_trace(go.Scatter( x=x, y=nth_read_data, @@ -441,13 +528,6 @@ def plot_pod5(output_dict, para_dict): size=5, line=dict(color='MediumPurple', width=2)), opacity=0.5)) - - # Update CSV - raw_row = [nth_read_data, signal_length, nth_read_mean, nth_read_median, nth_read_std, nth_read_skewness, nth_read_kurtosis] - qc_writer.writerow(raw_row) - - # Close CSV - qc_file.close() # Update the plot style fig.update_layout( @@ -457,7 +537,7 @@ def plot_pod5(output_dict, para_dict): font=dict(size=PLOT_FONT_SIZE) ) fig.update_traces(marker={'size': marker_size}) - fig.update_xaxes(title="Index") + # fig.update_xaxes(title="Index") # Append the dynamic HTML object to the output structure dynamic_html = fig.to_html(full_html=False) @@ -587,8 +667,8 @@ def plot_signal(output_data, para_dict): return output_html_plots -# Create a summary table for the basic statistics from the C++ output data def create_summary_table(output_data, plot_filepaths, file_type): + """Create the summary table for the basic statistics.""" plot_filepaths["basic_st"] = {} plot_filepaths["basic_st"]['file'] = "" plot_filepaths["basic_st"]['title'] = "Summary Table" @@ -597,10 +677,10 @@ def create_summary_table(output_data, plot_filepaths, file_type): file_type_label = file_type if file_type == 'FAST5s': file_type_label = 'FAST5' - plot_filepaths["basic_st"]['description'] = "{} Basic Statistics".format(file_type_label) if file_type == 'BAM': + # Add alignment statistics to the summary table table_str = "<table>\n<thead>\n<tr><th>Measurement</th><th>Mapped</th><th>Unmapped</th><th>All</th></tr>\n" \ "</thead> " table_str += "\n<tbody>" @@ -701,6 +781,55 @@ def create_summary_table(output_data, plot_filepaths, file_type): table_str += "\n</tbody>\n</table>" plot_filepaths["basic_st"]['detail'] = table_str +def create_modified_base_table(output_data, plot_filepaths, base_modification_threshold): + """Create a summary table for the base modifications.""" + plot_filepaths["base_mods"] = {} + plot_filepaths["base_mods"]['file'] = "" + plot_filepaths["base_mods"]['title'] = "Base Modifications" + plot_filepaths["base_mods"]['description'] = "Base modification statistics" + + # Set up the HTML table with two columns and no header + table_str = "<table>\n<tbody>" + + # Get the base modification statistics + total_predictions = output_data.modified_prediction_count + total_modifications = output_data.modified_base_count + total_forward_modifications = output_data.modified_base_count_forward + total_reverse_modifications = output_data.modified_base_count_reverse + # total_c_modifications = output_data.c_modified_base_count + cpg_modifications = output_data.cpg_modified_base_count + cpg_forward_modifications = output_data.cpg_modified_base_count_forward + cpg_reverse_modifications = output_data.cpg_modified_base_count_reverse + genome_cpg_count = output_data.cpg_genome_count + pct_modified_cpg = output_data.percent_modified_cpg + + # # Get the percentage of Cs in CpG sites + # cpg_modification_percentage = 0 + # if total_c_modifications > 0: + # cpg_modification_percentage = (cpg_modifications / total_c_modifications) * 100 + + # Add the base modification statistics to the table + table_str += "<tr><td>Total Predictions</td><td style=\"text-align:right\">{:,d}</td></tr>".format(total_predictions) + table_str += "<tr><td>Probability Threshold</td><td style=\"text-align:right\">{:.2f}</td></tr>".format(base_modification_threshold) + table_str += "<tr><td>Total Modified Bases in the Genome</td><td style=\"text-align:right\">{:,d}</td></tr>".format(total_modifications) + table_str += "<tr><td>Total in the Forward Strand</td><td style=\"text-align:right\">{:,d}</td></tr>".format(total_forward_modifications) + table_str += "<tr><td>Total in the Reverse Strand</td><td style=\"text-align:right\">{:,d}</td></tr>".format(total_reverse_modifications) + # table_str += "<tr><td>Total Modified Cs</td><td + # style=\"text-align:right\">{:,d}</td></tr>".format(total_c_modifications) + table_str += "<tr><td>Total CpG Sites in the Genome</td><td style=\"text-align:right\">{:,d}</td></tr>".format(genome_cpg_count) + table_str += "<tr><td>Total Modified Cs in CpG Sites (Forward Strand)</td><td style=\"text-align:right\">{:,d}</td></tr>".format(cpg_forward_modifications) + table_str += "<tr><td>Total Modified Cs in CpG Sites (Reverse Strand)</td><td style=\"text-align:right\">{:,d}</td></tr>".format(cpg_reverse_modifications) + table_str += "<tr><td>Total Modified Cs in CpG Sites (Combined Strands)</td><td style=\"text-align:right\">{:,d}</td></tr>".format(cpg_modifications) + table_str += "<tr><td>Percentage of CpG Sites with Modifications (Combined Strands)</td><td style=\"text-align:right\">{:.2f}%</td></tr>".format(pct_modified_cpg) + + # Add percentage of CpG sites with modifications (forward and reverse + # strands) + # table_str += "<tr><td>Percentage of CpG Sites with Modifications (Forward Strand)</td><td style=\"text-align:right\">{:.2f}%</td></tr>".format(pct_modified_cpg_forward) + # table_str += "<tr><td>Percentage of CpG Sites with Modifications (Reverse Strand)</td><td style=\"text-align:right\">{:.2f}%</td></tr>".format(pct_modified_cpg_reverse) + # table_str += "<tr><td>Percentage of Cs in CpG Sites</td><td style=\"text-align:right\">{:.2f}%</td></tr>".format(cpg_modification_percentage) + table_str += "\n</tbody>\n</table>" + plot_filepaths["base_mods"]['detail'] = table_str + def create_pod5_table(output_dict, plot_filepaths): """Create a summary table for the ONT POD5 signal data.""" plot_filepaths["basic_st"] = {} diff --git a/src/pod5_module.py b/src/pod5_module.py index 8d549e4..81466be 100644 --- a/src/pod5_module.py +++ b/src/pod5_module.py @@ -36,7 +36,7 @@ def generate_pod5_qc(input_data: dict) -> dict: if read_id_list and read_id not in read_id_list: logging.info("Skipping read ID: %s", read_id) continue - logging.info("Processing read ID: %s", read_id) + # logging.info("Processing read ID: %s", read_id) # Get the basecall signals read_signal = read.signal diff --git a/src/ref_query.cpp b/src/ref_query.cpp new file mode 100644 index 0000000..1c36662 --- /dev/null +++ b/src/ref_query.cpp @@ -0,0 +1,255 @@ +#include "ref_query.h" + +#include <string.h> +#include <iostream> +#include <unordered_map> +#include <string> +#include <fstream> +#include <sstream> +#include <vector> +#include <algorithm> + + +int RefQuery::setFilepath(std::string fasta_filepath) +{ + if (fasta_filepath == "") + { + std::cout << "No FASTA filepath provided" << std::endl; + return 1; + } + + this->fasta_filepath = fasta_filepath; + + // Parse the FASTA file + std::ifstream fasta_file(fasta_filepath); + if (!fasta_file.is_open()) + { + std::cout << "Could not open FASTA file " << fasta_filepath << std::endl; + exit(1); + } + + // Get the chromosomes and sequences + std::vector<std::string> chromosomes; + std::unordered_map<std::string, std::string> chr_to_seq; + std::string current_chr = ""; + std::string sequence = ""; + std::string line_str = ""; + while (std::getline(fasta_file, line_str)) + { + // Check if the line is a header + if (line_str[0] == '>') + { + // Header line, indicating a new chromosome + // Store the previous chromosome and sequence + if (current_chr != "") + { + chromosomes.push_back(current_chr); // Add the chromosome to the list + chr_to_seq[current_chr] = sequence; // Add the sequence to the map + sequence = ""; // Reset the sequence + } + + // Get the new chromosome + current_chr = line_str.substr(1); + + // Remove the description + size_t space_pos = current_chr.find(" "); + if (space_pos != std::string::npos) + { + current_chr.erase(space_pos); + } + + // Check if the chromosome is already in the map + if (chr_to_seq.find(current_chr) != chr_to_seq.end()) + { + std::cerr << "Duplicate chromosome " << current_chr << std::endl; + exit(1); + } + } else { + // Sequence line + sequence += line_str; + } + } + + // Add the last chromosome at the end of the file + if (current_chr != "") + { + chromosomes.push_back(current_chr); // Add the chromosome to the list + chr_to_seq[current_chr] = sequence; // Add the sequence to the map + } + + // Close the file + fasta_file.close(); + + // Sort the chromosomes + std::sort(chromosomes.begin(), chromosomes.end()); + + // Set the chromosomes and sequences + this->chromosomes = chromosomes; + this->chr_to_seq = chr_to_seq; + + // Find CpG sites + this->generateCpGMap(); + + return 0; +} + +void RefQuery::generateCpGMap() +{ + // Iterate over each chromosome + std::cout << "Locating CpG sites..." << std::endl; + for (const std::string& chr : this->chromosomes) + { + this->chr_to_cpg[chr] = std::unordered_set<int64_t>(); + this->chr_to_cpg_mod[chr] = std::unordered_set<int64_t>(); + + // Iterate over each position in the sequence + const std::string& sequence = this->chr_to_seq[chr]; + for (int32_t pos = 0; pos < (int32_t)sequence.size(); pos++) + { + // Check if the base is a C + if (sequence[pos] == 'C') + { + // Check if the next base is a G + if (pos + 1 < (int32_t)sequence.size() && sequence[pos + 1] == 'G') + { + // Add the CpG site to the map (1-based index) + int32_t pos1 = pos + 1; +// this->chr_pos_to_cpg[chr][pos1] = false; // Initialize as false since no modifications have been found + this->cpg_total_count++; + + this->chr_to_cpg[chr].insert(pos1); + + // Skip the next base + // pos++; + } + } + } + } +} + +void RefQuery::addCpGSiteModification(std::string chr, int64_t pos, int strand) +{ + // Update the CpG site if it exists + // Reverse strand (position is the G in the CpG site, so move back one + // position to get the C position stored in the map) + if (strand == 1) { + pos--; + } + + // Find the CpG site in the map +// if (this->chr_pos_to_cpg[chr].find(pos) != this->chr_pos_to_cpg[chr].end()) + if (this->chr_to_cpg[chr].find(pos) != this->chr_to_cpg[chr].end()) + { + this->test_count++; + + // Update the CpG site if not already modified + if (this->chr_to_cpg_mod[chr].find(pos) == this->chr_to_cpg_mod[chr].end()) + { + this->chr_to_cpg_mod[chr].insert(pos); + this->cpg_modified_count++; + } + + // Update the modified count +// this->cpg_modified_count++; + + // Update the CpG site if not already modified +// if (!this->chr_pos_to_cpg[chr][pos]) +// { +// this->chr_pos_to_cpg[chr][pos] = true; +// this->cpg_modified_count++; +// } +// this->chr_pos_to_cpg[chr][pos] = true; +// this->cpg_modified_count++; + } +} + +std::pair<uint64_t, uint64_t> RefQuery::getCpGModificationCounts(int strand) +{ + uint64_t modified_count = 0; + uint64_t unmodified_count = 0; + + std::cout << "Calculating CpG modification counts..." << std::endl; + + std::cout << " [TEST] Total CpG modified count: " << this->cpg_modified_count << std::endl; + std::cout << " [TEST] Total CpG sites: " << this->cpg_total_count << std::endl; + std::cout << " [TEST] CpG test count: " << this->test_count << std::endl; + modified_count = this->cpg_modified_count; + unmodified_count = this->cpg_total_count - this->cpg_modified_count; + +// // Iterate over each chromosome in the CpG site map +// // uint32_t cpg_site_count = 0; +// for (const auto& chr_pos_map : this->chr_pos_to_cpg) +// { +// // uint32_t chr_cpg_site_count = 0; +// +// // Iterate over each CpG site in the chromosome +// for (const auto& pos_to_cpg : chr_pos_map.second) +// { +// // Get the position and CpG site +// // int64_t pos = pos_to_cpg.first; +// bool is_cpg = pos_to_cpg.second; +// +// // Check if the CpG site is modified +// if (is_cpg) +// { +// // Increment the modified count +// modified_count++; +// } else { +// // Increment the unmodified count +// unmodified_count++; +// } +// // cpg_site_count++; +// // chr_cpg_site_count++; +// } +// } + std::cout << "Modified CpG sites: " << modified_count << std::endl; + std::cout << "Unmodified CpG sites: " << unmodified_count << std::endl; + std::cout << "Total CpG sites: " << modified_count + unmodified_count << std::endl; + std::cout << "Percentage of CpG sites modified: " << (double)modified_count / (modified_count + unmodified_count) * 100 << "%" << std::endl; + return std::make_pair(modified_count, unmodified_count); +} + +std::string RefQuery::getFilepath() +{ + return this->fasta_filepath; +} + +// Function to get the reference sequence at a given position range +std::string RefQuery::getBase(std::string chr, int64_t pos) +{ + // Convert positions from 1-indexed (reference) to 0-indexed (string indexing) + pos--; + + // Ensure that the position is not negative + if (pos < 0) + { + return "N"; + } + + // Get the sequence + const std::string& sequence = this->chr_to_seq[chr]; + + // Check if the position is out of range + if (static_cast<std::size_t>(pos) >= sequence.size()) { + throw std::out_of_range("Index out of range"); + } + + // Get the base + std::string base = sequence.substr(static_cast<std::size_t>(pos), 1); +// return str[static_cast<std::size_t>(index)]; + + // Get the base +// char base = sequence[pos]; + + // If the base is empty, return empty string + if (base == "") + { + return "N"; + } +// if (base == '\0') +// { +// return 'N'; +// } + + return base; +} diff --git a/tests/test_general.py b/tests/test_general.py index 5db2598..3258ac5 100644 --- a/tests/test_general.py +++ b/tests/test_general.py @@ -417,9 +417,7 @@ def test_n50(self, bam_output): @pytest.fixture(scope='class') def unmapped_bam_output(): - """ - Run the BAM module on unmapped inputs. - """ + """Run the BAM module on unmapped inputs.""" # Set parameters default_parameters = lrst.Input_Para() output_folder = os.path.abspath(str("output/")) @@ -436,7 +434,7 @@ def unmapped_bam_output(): # Add input files default_parameters.add_input_file(input_file) - # Run the FASTA statistics module + # Run the BAM statistics module output = lrst.Output_BAM() exit_code = lrst.callBAMModule(default_parameters, output) @@ -444,9 +442,7 @@ def unmapped_bam_output(): class TestUnmappedBAM: - """ - Tests for unmapped BAM inputs. - """ + """Tests for unmapped BAM inputs.""" # Ensure the module ran successfully @pytest.mark.dependency() @@ -486,6 +482,159 @@ def test_n50(self, unmapped_bam_output): assert n50_read_length == 22029 +@pytest.fixture(scope='class') +def forward_base_mod_output(): + """Run the BAM module on a read aligned to the forward strand with base modifications.""" + # Set parameters + default_parameters = lrst.Input_Para() + output_folder = os.path.abspath(str("output/")) + default_parameters.output_folder = output_folder + default_parameters.out_prefix = str("bam_") + default_parameters.base_mod_threshold = -1.0 + + # Check if running remotely + local_dir = os.path.expanduser('~/github/LongReadSum') + if os.getcwd() == local_dir: + input_file = os.path.join(local_dir, "SampleData/forward_mod.bam") # Local path + ref_file = os.path.join(local_dir, "SampleData/chr11.fa") + else: + input_file = os.path.abspath(str("SampleData/forward_mod.bam")) # Remote path + ref_file = os.path.abspath(str("SampleData/chr11.fa")) + + # Add input files + default_parameters.add_input_file(input_file) + default_parameters.ref_genome = ref_file + + # Run the BAM statistics module + output = lrst.Output_BAM() + exit_code = lrst.callBAMModule(default_parameters, output) + + yield [exit_code, output] + + +class TestForwardBaseModBAM: + """Tests for BAM inputs with base modifications on the forward strand.""" + + # Ensure the module ran successfully + @pytest.mark.dependency() + def test_success(self, forward_base_mod_output): + exit_code = forward_base_mod_output[0] + assert exit_code == 0 + + # Tests + @pytest.mark.dependency(depends=["TestForwardBaseModBAM::test_success"]) + def test_modified_base_count(self, forward_base_mod_output): + output_statistics = forward_base_mod_output[1] + modified_base_count = output_statistics.modified_base_count + assert modified_base_count == 682 + + @pytest.mark.dependency(depends=["TestForwardBaseModBAM::test_success"]) + def test_forward_modified_base_count(self, forward_base_mod_output): + output_statistics = forward_base_mod_output[1] + forward_modified_base_count = output_statistics.modified_base_count_forward + assert forward_modified_base_count == 682 + + @pytest.mark.dependency(depends=["TestForwardBaseModBAM::test_success"]) + def test_reverse_modified_base_count(self, forward_base_mod_output): + output_statistics = forward_base_mod_output[1] + reverse_modified_base_count = output_statistics.modified_base_count_reverse + assert reverse_modified_base_count == 0 + + @pytest.mark.dependency(depends=["TestForwardBaseModBAM::test_success"]) + def test_cpg_modified_base_count(self, forward_base_mod_output): + output_statistics = forward_base_mod_output[1] + cpg_modified_base_count = output_statistics.cpg_modified_base_count + assert cpg_modified_base_count == 621 + + @pytest.mark.dependency(depends=["TestForwardBaseModBAM::test_success"]) + def test_forward_cpg_modified_base_count(self, forward_base_mod_output): + output_statistics = forward_base_mod_output[1] + forward_cpg_modified_base_count = output_statistics.cpg_modified_base_count_forward + assert forward_cpg_modified_base_count == 621 + + @pytest.mark.dependency(depends=["TestForwardBaseModBAM::test_success"]) + def test_reverse_cpg_modified_base_count(self, forward_base_mod_output): + output_statistics = forward_base_mod_output[1] + reverse_cpg_modified_base_count = output_statistics.cpg_modified_base_count_reverse + assert reverse_cpg_modified_base_count == 0 + + +@pytest.fixture(scope='class') +def reverse_base_mod_output(): + """Run the BAM module on a read aligned to the reverse strand with base modifications.""" + # Set parameters + default_parameters = lrst.Input_Para() + output_folder = os.path.abspath(str("output/")) + default_parameters.output_folder = output_folder + default_parameters.out_prefix = str("bam_") + default_parameters.base_mod_threshold = -1.0 + + # Check if running remotely + local_dir = os.path.expanduser('~/github/LongReadSum') + if os.getcwd() == local_dir: + input_file = os.path.join(local_dir, "SampleData/reverse_mod.bam") + ref_file = os.path.join(local_dir, "SampleData/chr11.fa") + else: + input_file = os.path.abspath(str("SampleData/reverse_mod.bam")) + ref_file = os.path.abspath(str("SampleData/chr11.fa")) + + # Add input files + default_parameters.add_input_file(input_file) + default_parameters.ref_genome = ref_file + + # Run the BAM statistics module + output = lrst.Output_BAM() + exit_code = lrst.callBAMModule(default_parameters, output) + + yield [exit_code, output] + +class TestReverseBaseModBam: + """Tests for BAM inputs with base modifications on the reverse strand.""" + + # Ensure the module ran successfully + @pytest.mark.dependency() + def test_success(self, reverse_base_mod_output): + exit_code = reverse_base_mod_output[0] + assert exit_code == 0 + + # Tests + @pytest.mark.dependency(depends=["TestReverseBaseModBam::test_success"]) + def test_modified_base_count(self, reverse_base_mod_output): + output_statistics = reverse_base_mod_output[1] + modified_base_count = output_statistics.modified_base_count + assert modified_base_count == 548 + + @pytest.mark.dependency(depends=["TestReverseBaseModBam::test_success"]) + def test_forward_modified_base_count(self, reverse_base_mod_output): + output_statistics = reverse_base_mod_output[1] + forward_modified_base_count = output_statistics.modified_base_count_forward + assert forward_modified_base_count == 0 + + @pytest.mark.dependency(depends=["TestReverseBaseModBam::test_success"]) + def test_reverse_modified_base_count(self, reverse_base_mod_output): + output_statistics = reverse_base_mod_output[1] + reverse_modified_base_count = output_statistics.modified_base_count_reverse + assert reverse_modified_base_count == 548 + + @pytest.mark.dependency(depends=["TestReverseBaseModBam::test_success"]) + def test_cpg_modified_base_count(self, reverse_base_mod_output): + output_statistics = reverse_base_mod_output[1] + cpg_modified_base_count = output_statistics.cpg_modified_base_count + assert cpg_modified_base_count == 525 + + @pytest.mark.dependency(depends=["TestReverseBaseModBam::test_success"]) + def test_forward_cpg_modified_base_count(self, reverse_base_mod_output): + output_statistics = reverse_base_mod_output[1] + forward_cpg_modified_base_count = output_statistics.cpg_modified_base_count_forward + assert forward_cpg_modified_base_count == 0 + + @pytest.mark.dependency(depends=["TestReverseBaseModBam::test_success"]) + def test_reverse_cpg_modified_base_count(self, reverse_base_mod_output): + output_statistics = reverse_base_mod_output[1] + reverse_cpg_modified_base_count = output_statistics.cpg_modified_base_count_reverse + assert reverse_cpg_modified_base_count == 525 + + # sequencing_summary.txt tests @pytest.fixture(scope='class') def seqtxt_output():