
Commit

use datasets commandline because API will be deprecated
tongzhouxu committed May 29, 2024
1 parent ac47f04 commit 4ad8af4
Showing 4 changed files with 40 additions and 12 deletions.
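
Context for the change: assembly downloads now go through NCBI's datasets command-line tool (added to setup.py below as the ncbi-datasets-cli dependency) rather than the API slated for deprecation. The download itself happens in download_and_sketch_assembly, which this diff does not touch; the block below is only a sketch of the kind of CLI call the switch implies, with fetch_assembly as a hypothetical helper name.

import subprocess

def fetch_assembly(asm_acc, out_zip):
    # Hypothetical helper (not mashpit's actual code): download a single
    # assembly with the NCBI datasets CLI. 'datasets download genome accession'
    # writes a zip archive containing the genomic FASTA.
    subprocess.run(
        ['datasets', 'download', 'genome', 'accession', asm_acc,
         '--include', 'genome', '--filename', out_zip],
        check=True,
    )

fetch_assembly('GCA_000005845.2', 'GCA_000005845.2.zip')  # E. coli K-12 MG1655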
16 changes: 13 additions & 3 deletions mashpit/build.py
@@ -10,6 +10,7 @@
 import logging
 import sqlite3
 import requests
+import psutil
 import subprocess
 
 import urllib.request
@@ -22,6 +23,12 @@
 from html.parser import HTMLParser
 from sourmash import SourmashSignature, save_signatures, load_one_signature, MinHash
 
+def log_memory_usage(stage):
+    process = psutil.Process()
+    mem_info = process.memory_info()
+    logging.info(f"{stage} - RSS: {mem_info.rss / 1024 ** 2:.2f} MB, VMS: {mem_info.vms / 1024 ** 2:.2f} MB")
+
+
 # create connection to sqlite db file
 def create_connection(sql_path):
     conn = sqlite3.connect(sql_path)
@@ -414,6 +421,7 @@ def prepare(args):
     return db_folder,tmp_folder, conn
 
 def build_taxon(args):
+    log_memory_usage('start')
     # prepare the database folder and sqlite database
     db_folder,tmp_folder,conn = prepare(args)
     # format the pathogen name
@@ -437,6 +445,7 @@ def build_taxon(args):
     logging.info(f'Taxon name validated. Using {pathogen_name} as the taxon name.')
     # download the latest PD metadata files and get the PDG accession
     metadata_file_name,isolate_pds_file = download_metadata(pathogen_name,args.pd_version,tmp_folder)
+    log_memory_usage('After downloading files')
     df_metadata = pd.read_csv(os.path.join(tmp_folder,metadata_file_name),header=0,sep='\t')
     df_isolate_pds = pd.read_csv(os.path.join(tmp_folder,isolate_pds_file),header=0,sep='\t')
     # add cluster accession to metadata
@@ -448,18 +457,19 @@ def build_taxon(args):
     calculate_centroid(df_metadata_asm,pdg_acc,tmp_folder)
     df_cluster_center = pd.read_csv(os.path.join(tmp_folder,f'{pdg_acc}_cluster_center.tsv'),sep='\t')
     df_cluster_center_metadata = df_cluster_center.join(df_metadata_asm.drop('PDS_acc', axis=1).set_index('target_acc'),on='target_acc')
+    log_memory_usage('After centroid calculation')
     # download the centroid assembly using NCBI datasets
     gca_acc_list = df_cluster_center_metadata['asm_acc'].to_list()
     hash_number = args.number
     kmer_size = args.ksize
     download_and_sketch_assembly(gca_acc_list,hash_number,kmer_size,tmp_folder)
-
+    log_memory_usage('After downloading assemblies')
     # merge all signature files
     merge_sig(args,db_folder,tmp_folder)
-
+    log_memory_usage('After merging signature files')
     # import metadata to database
     import_metadata(df_metadata,df_cluster_center_metadata,conn)
-
+    log_memory_usage('After importing metadata')
     # add description to database
     c = conn.cursor()
     c.execute("INSERT INTO DESC VALUES ('Type','Taxonomy');")
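
For reference, the new log_memory_usage helper is self-contained apart from psutil; a minimal standalone sketch, assuming logging has been configured at INFO level (which the example does explicitly):

import logging
import psutil

logging.basicConfig(level=logging.INFO)

def log_memory_usage(stage):
    # Log the current process's resident (RSS) and virtual (VMS) memory in MB.
    process = psutil.Process()
    mem_info = process.memory_info()
    logging.info(f"{stage} - RSS: {mem_info.rss / 1024 ** 2:.2f} MB, VMS: {mem_info.vms / 1024 ** 2:.2f} MB")

log_memory_usage('After downloading assemblies')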
4 changes: 2 additions & 2 deletions mashpit/mashpit.py
@@ -6,7 +6,7 @@
 from mashpit import build
 from mashpit import update
 from mashpit import query
-#from mashpit import webserver
+from mashpit import webserver
 
 
 def commandToArgs(commandline):
@@ -54,7 +54,7 @@ def commandToArgs(commandline):
     subparser_query.set_defaults(func=query.query)
 
     # "webserver" subcommand
-    # subparser_webserver.set_defaults(func=webserver.webserver)
+    subparser_webserver.set_defaults(func=webserver.webserver)
 
     args = parser.parse_args(commandline)
     return args
31 changes: 24 additions & 7 deletions mashpit/query.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 import os
+import glob
 import logging
 import ntpath
 import screed
@@ -49,7 +50,7 @@ def generate_query_table(conn, sorted_asm_similarity_dict):
     if db_type == 'Taxonomy':
         pds_list = output_df['PDS_acc'].to_list()
         cluster_link = [f'https://www.ncbi.nlm.nih.gov/pathogens/isolates/#{pds}' for pds in pds_list]
-        output_df['link'] = cluster_link
+        output_df['SNP_tree_link'] = cluster_link
 
     return output_df
 
@@ -107,7 +108,7 @@ def generate_mashtree(output_df,min_similarity,query_name,sig_path,added_annotat
         f.write(newick_str)
     tree = Phylo.read(f'{query_name}_tree.newick', "newick")
     n = len(tree.get_terminals())
-    fig = plt.figure(figsize=(10, n*0.35), dpi=300)
+    fig = plt.figure(figsize=(10, n*0.2), dpi=150)
     axes = fig.add_subplot(1, 1, 1)
     # disable the axes and borders
     axes.set_frame_on(False)
@@ -147,11 +148,27 @@ def query(args):
     if not os.path.exists(db_folder):
         logging.error('Database path not found.')
         exit(1)
-    folder_name = os.path.basename(db_folder)
-    sql_path = os.path.join(db_folder,f'{folder_name}.db')
-    sig_path = os.path.join(db_folder, f'{folder_name}.sig')
-    if not (os.path.exists(sql_path) and os.path.exists(sig_path)):
-        logging.error('Database incomplete.')
+    # find files ending with .db
+    sql_path = glob.glob(os.path.join(db_folder,'*.db'))
+    if len(sql_path) == 0:
+        logging.error('Database path not found.')
+        exit(1)
+    if len(sql_path) > 1:
+        logging.error('Multiple database files found.')
+        exit(1)
+    sql_path = sql_path[0]
+    # find files ending with .sig
+    sig_path = glob.glob(os.path.join(db_folder,'*.sig'))
+    if len(sig_path) == 0:
+        logging.error('Signature file not found.')
+        exit(1)
+    if len(sig_path) > 1:
+        logging.error('Multiple signature files found.')
+        exit(1)
+    sig_path = sig_path[0]
+    # check if the basename of .db and .sig are the same
+    if os.path.basename(sql_path).split('.')[0] != os.path.basename(sig_path).split('.')[0]:
+        logging.error('Database sql and signature files do not match.')
         exit(1)
 
     # check the hash number and kmer size in the database
1 change: 1 addition & 0 deletions setup.py
@@ -25,5 +25,6 @@
         'tqdm',
         'flask',
         'dask[dataframe]',
+        'ncbi-datasets-cli',
     ]
 )
