Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix #723

Merged
merged 2 commits into from
Dec 23, 2024
Merged

fix #723

Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 89 additions & 49 deletions utils/speed.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,37 +6,76 @@

import m3u8
from aiohttp import ClientSession, TCPConnector
from multidict import CIMultiDictProxy

from utils.config import config
from utils.tools import is_ipv6, remove_cache_info


async def get_speed_with_download(url: str, session: "ClientSession | None" = None,
                                  timeout: int = config.sort_timeout) -> dict[str, float | None]:
    """
    Measure download throughput and first-response delay for a url.

    Args:
        url: The url to download from.
        session: Optional shared aiohttp session; a temporary one is created
            (and closed here) when omitted.
        timeout: Total request timeout in seconds.

    Returns:
        Dict with 'speed' (MB/s, 0 when nothing was read) and 'delay'
        (milliseconds until the response arrived, or None on failure/404).
    """
    start_time = time()
    total_size = 0
    total_time = 0
    info = {'speed': None, 'delay': None}
    created_session = session is None
    if created_session:
        session = ClientSession(connector=TCPConnector(ssl=False), trust_env=True)
    try:
        async with session.get(url, timeout=timeout) as response:
            if response.status == 404:
                return info
            info['delay'] = int(round((time() - start_time) * 1000))
            async for chunk in response.content.iter_any():
                if chunk:
                    total_size += len(chunk)
    except Exception:
        # Best-effort measurement: network/timeout errors simply yield speed 0.
        pass
    finally:
        if created_session:
            await session.close()
        # Speed is computed even on the early 404/exception paths so the
        # returned dict is always fully populated.
        total_time += time() - start_time
        info['speed'] = (total_size / total_time if total_time > 0 else 0) / 1024 / 1024
    return info


async def get_m3u8_headers(url: str, session: "ClientSession | None" = None,
                           timeout: int = 5) -> "CIMultiDictProxy[str] | dict":
    """
    Fetch the response headers of a (suspected) m3u8 url via a HEAD request.

    Args:
        url: The url to probe.
        session: Optional shared aiohttp session; a temporary one is created
            (and closed here) when omitted.
        timeout: Total request timeout in seconds.

    Returns:
        The response headers, or an empty dict when the request fails.
    """
    created_session = session is None
    if created_session:
        session = ClientSession(connector=TCPConnector(ssl=False), trust_env=True)
    try:
        async with session.head(url, timeout=timeout) as response:
            return response.headers
    except Exception:
        # Narrowed from a bare `except:` so asyncio.CancelledError still
        # propagates; any ordinary failure means "no usable headers".
        pass
    finally:
        if created_session:
            await session.close()
    return {}


def check_m3u8_valid(headers: "CIMultiDictProxy[str] | dict") -> bool:
    """
    Check whether response headers identify an m3u8 (HLS) playlist.

    Args:
        headers: Response headers (aiohttp's CIMultiDictProxy or a plain dict).

    Returns:
        True when the Content-Type is an Apple mpegurl type, False otherwise
        (including when the header is missing or empty).
    """
    content_type = headers.get('Content-Type')
    if not content_type:
        return False
    return 'application/vnd.apple.mpegurl' in content_type.lower()


async def get_speed_m3u8(url: str, timeout: int = config.sort_timeout) -> dict[str, float | None]:
Expand All @@ -47,44 +86,45 @@ async def get_speed_m3u8(url: str, timeout: int = config.sort_timeout) -> dict[s
try:
url = quote(url, safe=':/?$&=@[]').partition('$')[0]
async with ClientSession(connector=TCPConnector(ssl=False), trust_env=True) as session:
async with session.head(url, timeout=5) as response:
content_type = response.headers.get('Content-Type')
if content_type:
content_type = content_type.lower()
location = response.headers.get('Location')
if 'application/vnd.apple.mpegurl' in content_type:
url = location or url
headers = await get_m3u8_headers(url, session)
if check_m3u8_valid(headers):
location = headers.get('Location')
if location:
info.update(await get_speed_m3u8(location, timeout))
else:
m3u8_obj = m3u8.load(url, timeout=2)
playlists = m3u8_obj.data.get('playlists')
segments = m3u8_obj.segments
if not segments and playlists:
parsed_url = urlparse(url)
url = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path.rsplit('/', 1)[0]}/{playlists[0].get('uri', '')}"
uri_headers = await get_m3u8_headers(url, session)
if not check_m3u8_valid(uri_headers):
if uri_headers.get('Content-Length'):
info.update(await get_speed_with_download(url, session, timeout))
return info
m3u8_obj = m3u8.load(url, timeout=2)
playlists = m3u8_obj.data.get('playlists')
segments = m3u8_obj.segments
if not segments and playlists:
parsed_url = urlparse(url)
url = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path.rsplit('/', 1)[0]}/{playlists[0].get('uri', '')}"
m3u8_obj = m3u8.load(url, timeout=2)
segments = m3u8_obj.segments
if not segments:
return info
ts_urls = [segment.absolute_uri for segment in segments]
speed_list = []
start_time = time()
for ts_url in ts_urls:
if time() - start_time > timeout:
break
download_info = await get_speed_with_download(ts_url, timeout)
speed_list.append(download_info['speed'])
if info['delay'] is None and download_info['delay'] is not None:
info['delay'] = download_info['delay']
info['speed'] = sum(speed_list) / len(speed_list) if speed_list else 0
elif location:
info.update(await get_speed_m3u8(location, timeout))
elif response.headers.get('Content-Length'):
info.update(await get_speed_with_download(url, timeout))
else:
return info
if not segments:
return info
ts_urls = [segment.absolute_uri for segment in segments]
speed_list = []
start_time = time()
for ts_url in ts_urls:
if time() - start_time > timeout:
break
download_info = await get_speed_with_download(ts_url, session, timeout)
speed_list.append(download_info['speed'])
if info['delay'] is None and download_info['delay'] is not None:
info['delay'] = download_info['delay']
info['speed'] = sum(speed_list) / len(speed_list) if speed_list else 0
elif headers.get('Content-Length'):
info.update(await get_speed_with_download(url, session, timeout))
else:
return info
except:
pass
finally:
return info
return info


async def get_delay_requests(url, timeout=config.sort_timeout, proxy=None):
Expand Down