From 37cb8a778d507e076078255194800673cd942c16 Mon Sep 17 00:00:00 2001 From: Apurba Bose <44209735+apbose@users.noreply.github.com> Date: Thu, 9 Jan 2025 13:47:18 -0800 Subject: [PATCH] Cherrypick: nccl ops multi gpu (#3342) Signed-off-by: Naren Dasan Co-authored-by: Naren Dasan <1790613+narendasan@users.noreply.github.com> --- examples/distributed_inference/README.md | 34 + .../distributed_inference/requirement.txt | 3 +- .../tensor_parallel_initialize_dist.py | 67 ++ .../tensor_parallel_llama3.py | 26 +- .../tensor_parallel_simple_example.py | 24 +- py/torch_tensorrt/dynamo/_defaults.py | 1 + py/torch_tensorrt/dynamo/_settings.py | 3 + py/torch_tensorrt/dynamo/backend/backends.py | 56 +- .../dynamo/conversion/__init__.py | 8 +- .../dynamo/conversion/aten_ops_converters.py | 9 +- .../dynamo/conversion/converter_utils.py | 74 +- .../conversion/custom_ops_converters.py | 61 + .../dynamo/conversion/impl/__init__.py | 1 + .../conversion/impl/elementwise/base.py | 4 +- .../dynamo/conversion/impl/nccl_ops.py | 120 ++ .../dynamo/conversion/impl/select.py | 4 +- .../dynamo/lowering/passes/__init__.py | 1 + .../lowering/passes/_aten_lowering_pass.py | 2 + .../passes/_modify_reshape_complex_nodes.py | 105 ++ .../_replace_complex_placeholder_to_tuple.py | 112 ++ .../lowering/passes/fuse_distributed_ops.py | 80 ++ .../dynamo/lowering/passes/pass_utils.py | 39 + pyproject.toml | 3 +- uv.lock | 1072 ++++++++++++++++- 24 files changed, 1845 insertions(+), 64 deletions(-) create mode 100644 examples/distributed_inference/tensor_parallel_initialize_dist.py create mode 100644 py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py create mode 100644 py/torch_tensorrt/dynamo/lowering/passes/_modify_reshape_complex_nodes.py create mode 100644 py/torch_tensorrt/dynamo/lowering/passes/_replace_complex_placeholder_to_tuple.py create mode 100644 py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py diff --git a/examples/distributed_inference/README.md b/examples/distributed_inference/README.md index 4c88570b6f..d4cf9508e1 100644 --- a/examples/distributed_inference/README.md +++ b/examples/distributed_inference/README.md @@ -14,3 +14,37 @@ See the examples started with `data_parallel` for more details. Here we use torch.distributed as an example, but compilation with tensor parallelism is agnostic to the implementation framework as long as the module is properly sharded. torchrun --nproc_per_node=2 tensor_parallel_llama2.py + +3. Tensor parallel distributed inference using nccl ops plugin + + apt install libmpich-dev + + apt install libopenmpi-dev + + #For python3.10 + + pip install tensorrt-llm + + For other python versions, you need to load the libnvinfer_plugin_tensorrt_llm.so. Please set that in the environment variable export TRTLLM_PLUGINS_PATH={lib_path}. For example, we have already set the variable in initialize_distributed_env(). You can replace this with your TRTLLM_PLUGINS_PATH and unset it there + + #then pip install the tensorrt and torch version compatible with installed torchTRT + + mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py + + #For other python + +4. Tensor parallel distributed llama3 inference using nccl ops plugin + + apt install libmpich-dev + + apt install libopenmpi-dev + +#For python3.10 + + pip install tensorrt-llm + + For other python versions, you need to load the libnvinfer_plugin_tensorrt_llm.so + + #then pip install the tensorrt and torch version compatible with installed torchTRT + + mpirun -n 2 --allow-run-as-root python tensor_parallel_llama3.py diff --git a/examples/distributed_inference/requirement.txt b/examples/distributed_inference/requirement.txt index 6d8e0aa9f2..bb3b0f28f1 100644 --- a/examples/distributed_inference/requirement.txt +++ b/examples/distributed_inference/requirement.txt @@ -1,3 +1,4 @@ accelerate transformers -diffusers \ No newline at end of file +diffusers +tensorrt-llm \ No newline at end of file diff --git a/examples/distributed_inference/tensor_parallel_initialize_dist.py b/examples/distributed_inference/tensor_parallel_initialize_dist.py new file mode 100644 index 0000000000..21e4cbc282 --- /dev/null +++ b/examples/distributed_inference/tensor_parallel_initialize_dist.py @@ -0,0 +1,67 @@ +import logging +import os +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +import numpy as np +import tensorrt as trt +import torch +import torch.distributed as dist +from torch.distributed._tensor.device_mesh import init_device_mesh + + +def find_repo_root(max_depth=10): + dir_path = os.path.dirname(os.path.realpath(__file__)) + for i in range(max_depth): + files = os.listdir(dir_path) + if "MODULE.bazel" in files: + return dir_path + else: + dir_path = os.path.dirname(dir_path) + + raise RuntimeError("Could not find repo root") + + +def initialize_logger(rank, logger_file_name): + logger = logging.getLogger() + logger.setLevel(logging.INFO) + fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w") + fh.setLevel(logging.INFO) + logger.addHandler(fh) + return logger + + +# This is required for env initialization since we use mpirun +def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500): + local_rank = int( + os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count()) + ) + world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size)) + + # Set up environment variable to run with mpirun + os.environ["RANK"] = str(local_rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + os.environ["TRTLLM_PLUGINS_PATH"] = ( + find_repo_root() + "/lib/libnvinfer_plugin_tensorrt_llm.so" + ) + + # Necessary to assign a device to each rank. + torch.cuda.set_device(local_rank) + + # We use nccl backend + dist.init_process_group("nccl") + + # set a manual seed for reproducibility + torch.manual_seed(1111) + + device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,)) + rank = device_mesh.get_rank() + assert rank == local_rank + logger = initialize_logger(rank, logger_file_name) + device_id = ( + rank % torch.cuda.device_count() + ) # Ensure each rank gets a unique device + torch.cuda.set_device(device_id) + + return device_mesh, world_size, rank, logger diff --git a/examples/distributed_inference/tensor_parallel_llama3.py b/examples/distributed_inference/tensor_parallel_llama3.py index fc03a64386..998c378be2 100644 --- a/examples/distributed_inference/tensor_parallel_llama3.py +++ b/examples/distributed_inference/tensor_parallel_llama3.py @@ -5,27 +5,25 @@ import time import torch -import torch_tensorrt from llama3_model import ModelArgs, ParallelTransformer +from tensor_parallel_initialize_dist import initialize_distributed_env from torch.distributed._composable.fsdp import MixedPrecisionPolicy from torch.distributed._composable.fsdp.fully_shard import fully_shard from torch.distributed._tensor import Replicate, Shard from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( checkpoint_wrapper, ) -from torch.distributed.device_mesh import DeviceMesh, init_device_mesh -_rank = int(os.environ["RANK"]) -_world_size = int(os.environ["WORLD_SIZE"]) -tp_size = 2 - -logger = logging.getLogger() -logger.setLevel(logging.INFO) -fh = logging.FileHandler(f"./tensor_parallel_log_{_rank}.log", mode="w") -fh.setLevel(logging.INFO) -logger.addHandler(fh) +device_mesh, _world_size, _rank, logger = initialize_distributed_env( + "./tensor_parallel_llama3" +) +# Import should be after initialization of the TRT-LLM plugin .so path +import tensorrt_llm -tp_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,)) +logger.info(f"Starting PyTorch TP example on rank {_rank}.") +assert ( + _world_size % 2 == 0 +), f"TP examples require even number of GPUs, but got {_world_size} gpus" model_args = ModelArgs( vocab_size=32000, @@ -38,7 +36,7 @@ ) with torch.no_grad(): - model = ParallelTransformer(model_args, tp_mesh) + model = ParallelTransformer(model_args, device_mesh) torch.manual_seed(0) inp = torch.randint(32000, (8, 256), device="cuda") python_result = model(inp) @@ -53,7 +51,7 @@ "use_python_runtime": True, "workspace_size": 1 << 33, "debug": False, - "timing_cache_path": "/opt/file/cache/timing_cache_llama.bin", + "use_aot_joint_export": False, }, dynamic=False, ) diff --git a/examples/distributed_inference/tensor_parallel_simple_example.py b/examples/distributed_inference/tensor_parallel_simple_example.py index 470487a751..837648fdb4 100755 --- a/examples/distributed_inference/tensor_parallel_simple_example.py +++ b/examples/distributed_inference/tensor_parallel_simple_example.py @@ -1,18 +1,22 @@ -import os -import sys import time +import tensorrt as trt import torch import torch.nn as nn import torch_tensorrt +from tensor_parallel_initialize_dist import initialize_distributed_env from torch.distributed._tensor import Shard -from torch.distributed._tensor.device_mesh import init_device_mesh from torch.distributed.tensor.parallel import ( ColwiseParallel, RowwiseParallel, parallelize_module, ) +device_mesh, _world_size, _rank, logger = initialize_distributed_env( + "./tensor_parallel_simple_example" +) +import tensorrt_llm + """ This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py """ @@ -36,14 +40,7 @@ def forward(self, x): return x -# create a device mesh based on the given world_size. -_world_size = int(os.environ["WORLD_SIZE"]) - -device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,)) -_rank = device_mesh.get_rank() - - -print(f"Starting PyTorch TP example on rank {_rank}.") +logger.info(f"Starting PyTorch TP example on rank {_rank}.") assert ( _world_size % 2 == 0 ), f"TP examples require even number of GPUs, but got {_world_size} gpus" @@ -78,6 +75,7 @@ def forward(self, x): "enabled_precisions": {torch.float32, torch.float16}, "use_python_runtime": True, "min_block_size": 1, + "use_aot_joint_export": False, }, dynamic=False, ) @@ -91,9 +89,9 @@ def forward(self, x): output = tp_model(inp) end = time.time() if i == 0: - print(f"Compilation time is {end-start}") + logger.info(f"Compilation time is {end-start}") assert ( python_result - output ).std() < 0.01, "Compilation result is not correct." elif _rank == 0: - print(f"Inference time is {end-start}") + logger.info(f"Inference time is {end-start}") diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 76630a75a5..18932e6cd0 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -46,6 +46,7 @@ IMMUTABLE_WEIGHTS = True ENABLE_WEIGHT_STREAMING = False ENABLE_CROSS_COMPILE_FOR_WINDOWS = False +USE_AOT_JOINT_EXPORT = True def default_device() -> Device: diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 7a22663af3..05fb5ce094 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -33,6 +33,7 @@ STRIP_ENGINE_WEIGHTS, TIMING_CACHE_PATH, TRUNCATE_DOUBLE, + USE_AOT_JOINT_EXPORT, USE_EXPLICIT_TYPING, USE_FAST_PARTITIONER, USE_FP32_ACC, @@ -91,6 +92,7 @@ class CompilationSettings: enable_weight_streaming (bool): Enable weight streaming. enable_cross_compile_for_windows (bool): By default this is False means TensorRT engines can only be executed on the same platform where they were built. True will enable cross-platform compatibility which allows the engine to be built on Linux and run on Windows + use_aot_joint_export (bool): Use aot_export_joint_simple, else wrap backend with AOT_autograd, required for distributed tensors """ enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS) @@ -131,6 +133,7 @@ class CompilationSettings: immutable_weights: bool = IMMUTABLE_WEIGHTS enable_weight_streaming: bool = ENABLE_WEIGHT_STREAMING enable_cross_compile_for_windows: bool = ENABLE_CROSS_COMPILE_FOR_WINDOWS + use_aot_joint_export: bool = USE_AOT_JOINT_EXPORT _SETTINGS_TO_BE_ENGINE_INVARIANT = ( diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index 59bd7d011d..ef04745562 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -1,17 +1,20 @@ from __future__ import annotations +import functools import logging import unittest from typing import Any, Callable, Sequence import torch import torch._dynamo as td +from torch._dynamo.backends.common import aot_autograd from torch._dynamo.utils import detect_fake_mode from torch._functorch.aot_autograd import aot_export_joint_simple from torch_tensorrt.dynamo import CompilationSettings from torch_tensorrt.dynamo._compiler import compile_module from torch_tensorrt.dynamo.lowering import ( get_decompositions, + modify_reshape_complex_nodes, post_lowering, remove_detach, remove_sym_nodes, @@ -49,7 +52,25 @@ def aot_torch_tensorrt_aten_backend( gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any ) -> torch.nn.Module: settings, engine_cache = parse_dynamo_kwargs(kwargs) - return _pretraced_backend(gm, sample_inputs, settings, engine_cache) + if settings.use_aot_joint_export: + return _pretraced_backend(gm, sample_inputs, settings, engine_cache) + logger.debug("Wrapping the backend with aot_autograd\n") + _pretraced_backend_autograd = functools.partial( + _pretraced_backend, settings=settings, engine_cache=engine_cache + ) + settings_aot_autograd = {} + settings_aot_autograd["decompostions"] = get_decompositions( + settings.enable_experimental_decompositions + ) + # This is added since detach lowering leads to alias nodes + # Error - View operation returned a tensor that is the same as the input base tensor + # torch nop_decompositions in torch/_decomp/decompositions.py + if aten.detach in settings_aot_autograd["decompositions"]: + del settings_aot_autograd["decompositions"][aten.detach] + return aot_autograd( + fw_compiler=_pretraced_backend_autograd, + decompositions=get_decompositions(settings.enable_experimental_decompositions), + )(gm, sample_inputs) def _pretraced_backend( @@ -89,15 +110,26 @@ def _pretraced_backend( # Remove detach nodes remove_detach(gm, settings) + complexInputIndices = [] + for i, torch_input in enumerate(torch_inputs): + if torch_inputs[i].dtype == torch.complex64: + complexInputIndices.append(i) + torch_input_real = torch_inputs[i].real + torch_input_imaginary = torch_inputs[i].imag + torch_inputs[i] = torch.stack( + (torch_input_real, torch_input_imaginary), dim=-1 + ) + # Invoke AOTAutograd to translate operators to aten - gm = aot_export_joint_simple( - gm, - sample_inputs, - trace_joint=False, - decompositions=get_decompositions( - settings.enable_experimental_decompositions - ), - ) + if settings.use_aot_joint_export: + gm = aot_export_joint_simple( + gm, + sample_inputs, + trace_joint=False, + decompositions=get_decompositions( + settings.enable_experimental_decompositions + ), + ) logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph)) @@ -105,6 +137,12 @@ def _pretraced_backend( logger.debug("Lowered Input graph:\n " + str(gm.graph)) + if complexInputIndices: + modify_reshape_complex_nodes(gm, complexInputIndices) + logger.debug( + "Input graph after modifying complex nodes:\n " + str(gm.graph) + ) + torchtrt_inputs = prepare_inputs( torch_inputs, disable_memory_format_check=True ) diff --git a/py/torch_tensorrt/dynamo/conversion/__init__.py b/py/torch_tensorrt/dynamo/conversion/__init__.py index 26a428ed8d..f41540a0cf 100644 --- a/py/torch_tensorrt/dynamo/conversion/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/__init__.py @@ -1,4 +1,10 @@ -from . import aten_ops_converters, ops_evaluators, plugins, prims_ops_converters +from . import ( + aten_ops_converters, + custom_ops_converters, + ops_evaluators, + plugins, + prims_ops_converters, +) from ._conversion import convert_module, interpret_module_to_result from ._ConversionContext import ConversionContext from ._ConverterRegistry import * # noqa: F403 diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 16d20d3c76..5254c6a0ac 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -2,7 +2,7 @@ import logging import operator -from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union +from typing import Callable, Dict, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -16,6 +16,7 @@ has_static_shapes_in_args, ) from torch_tensorrt.dynamo.conversion.converter_utils import ( + args_bounds_check, enforce_tensor_types, get_positive_dim, is_only_operator_on_placeholder, @@ -25,12 +26,6 @@ _LOGGER: logging.Logger = logging.getLogger(__name__) -def args_bounds_check( - args: Tuple[Argument, ...], i: int, replacement: Optional[Any] = None -) -> Any: - return args[i] if len(args) > i and args[i] is not None else replacement - - def get_ir(target: Target) -> SourceIR: target_module = getattr(target, "__module__", "None") if any( diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 937ed27fe4..f4bb4877cc 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -1,10 +1,11 @@ import collections +import ctypes import functools import logging +import os from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload import numpy as np -import tensorrt as trt import torch import torch_tensorrt.dynamo.conversion.impl as impl from torch.fx.node import Argument, Target @@ -18,6 +19,8 @@ DynamoConverterImplSignature, ) +import tensorrt as trt + from ..types import Shape, TRTDataType, TRTLayer, TRTTensor _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -913,3 +916,72 @@ def set_layer_name( else f"{source_ir}_ops.{target.__name__}" ) layer.name = f"[{layer.type.name}]-[{target_name}]-[{name}]" + + +def args_bounds_check( + args: Tuple[Argument, ...], i: int, replacement: Optional[Any] = None +) -> Any: + return args[i] if len(args) > i and args[i] is not None else replacement + + +def load_tensorrt_llm() -> bool: + """ + Attempts to load the TensorRT-LLM plugin and initialize it. + + Returns: + bool: True if the plugin was successfully loaded and initialized, False otherwise. + """ + try: + import tensorrt_llm as trt_llm # noqa: F401 + + _LOGGER.info("TensorRT-LLM successfully imported") + return True + except (ImportError, AssertionError) as e_import_error: + # Check for environment variable for the plugin library path + plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH") + if not plugin_lib_path: + _LOGGER.warning( + "TensorRT-LLM is not installed. Please install TensorRT-LLM or set TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops", + ) + return False + + _LOGGER.info(f"TensorRT-LLM Plugin lib path found: {plugin_lib_path}") + try: + # Load the shared library + handle = ctypes.CDLL(plugin_lib_path) + _LOGGER.info(f"Successfully loaded plugin library: {plugin_lib_path}") + except OSError as e_os_error: + _LOGGER.error( + f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}" + f"Ensure the path is correct and the library is compatible", + exc_info=e_os_error, + ) + return False + + try: + # Configure plugin initialization arguments + handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p] + handle.initTrtLlmPlugins.restype = ctypes.c_bool + except AttributeError as e_plugin_unavailable: + _LOGGER.warning( + "Unable to initialize the TensorRT-LLM plugin library", + exc_info=e_plugin_unavailable, + ) + return False + + try: + # Initialize the plugin + TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm" + if handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8")): + _LOGGER.info("TensorRT-LLM plugin successfully initialized") + return True + else: + _LOGGER.warning("TensorRT-LLM plugin library failed in initialization") + return False + except Exception as e_initialization_error: + _LOGGER.warning( + "Exception occurred during TensorRT-LLM plugin library initialization", + exc_info=e_initialization_error, + ) + return False + return False diff --git a/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py new file mode 100644 index 0000000000..17850fabce --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py @@ -0,0 +1,61 @@ +# mypy: disallow-untyped-decorators=False + +import logging +from typing import Dict, Sequence, Tuple, Union + +from torch.fx.node import Argument, Target +from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion import impl +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( + dynamo_tensorrt_converter, +) +from torch_tensorrt.dynamo.conversion.converter_utils import load_tensorrt_llm +from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import ( + tensorrt_fused_nccl_all_gather_op, + tensorrt_fused_nccl_reduce_scatter_op, +) + +import tensorrt as trt + +_LOGGER: logging.Logger = logging.getLogger(__name__) + +if load_tensorrt_llm(): + + @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op) + def fused_nccl_gather( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, + ) -> Union[trt.ITensor, Sequence[trt.ITensor]]: + return impl.distributed.nccl_gather( + ctx, + target, + SourceIR.ATEN, + name, + [args[0]], + ) + + @dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op) + def fused_nccl_reduce_scatter( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, + ) -> Union[trt.ITensor, Sequence[trt.ITensor]]: + return impl.distributed.nccl_reduce_scatter( + ctx, + target, + SourceIR.ATEN, + name, + [args[0]], + ) + + breakpoint() +else: + _LOGGER.debug( + "Did not load torch.distributed converters since TensorRT-LLM is not available" + ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py index c1187f0dd9..75f7492591 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -13,6 +13,7 @@ full, grid, matmul, + nccl_ops, normalization, pad, permutation, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py index 3ce1714a7d..4a4b33abea 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py @@ -3,7 +3,6 @@ from typing import Any, Callable, Optional, Union import numpy as np -import tensorrt as trt import torch from torch.fx.node import Target from torch_tensorrt import _enums @@ -14,12 +13,13 @@ broadcast_to_same_shape, cast_trt_tensor, get_trt_tensor, - broadcast, has_dynamic_shape, set_layer_name, ) from torch_tensorrt.dynamo.types import TRTElementWiseOp, TRTTensor +import tensorrt as trt + def get_python_op_from_trt_elementwise_op( trt_op: TRTElementWiseOp, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py b/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py new file mode 100644 index 0000000000..013268f803 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py @@ -0,0 +1,120 @@ +import os +from enum import IntEnum, IntFlag, auto +from typing import Optional, Tuple, Union + +import numpy as np +from torch.fx.node import Argument, Target +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.fx.converters.converter_utils import SourceIR, set_layer_name + +import tensorrt as trt + + +# class for AllReduce +class AllReduceStrategy(IntEnum): + """Warning: actual definition is in kernels/customAllReduceKernels.h. + + They must be kept in sync. + """ + + NCCL = 0 + ONESHOT = 1 + TWOSHOT = 2 + AUTO = 3 + + +class AllReduceConfig(IntFlag): + """Warning: actual definition is in kernels/customAllReduceKernels.h. + + They must be kept in sync + """ + + USE_MEMCPY = auto() + PUSH_MODE = auto() + + +def nccl_gather( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + plug_inputs: Tuple[Argument, ...], +) -> trt.ITensor: + allgather_plg_creator = trt.get_plugin_registry().get_plugin_creator( + "AllGather", "1", "tensorrt_llm" + ) + assert allgather_plg_creator is not None + _world_size = os.environ.get("WORLD_SIZE") + if _world_size is not None: + world_size = int(_world_size) + else: + raise RuntimeError( + "The WORLD_SIZE env variable is not set in distributed environment" + ) + group = list(range(world_size)) + group = trt.PluginField( + "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32 + ) + p_dtype = trt.float32 + pf_type = trt.PluginField( + "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32 + ) + pfc = trt.PluginFieldCollection([group, pf_type]) + allgather = allgather_plg_creator.create_plugin("allgather", pfc) + layer = ctx.net.add_plugin_v2(plug_inputs, allgather) + set_layer_name(layer, target, name, source_ir) + return layer.get_output(0) + + +def nccl_reduce_scatter( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + plug_inputs: Tuple[Argument, ...], +) -> trt.ITensor: + allreduce_plg_creator = trt.get_plugin_registry().get_plugin_creator( + "ReduceScatter", "1", "tensorrt_llm" + ) + + assert allreduce_plg_creator is not None + + counter = 0 + strategy = AllReduceStrategy.NCCL + config = AllReduceConfig(0) + _world_size = os.environ.get("WORLD_SIZE") + if _world_size is not None: + world_size = int(_world_size) + else: + raise RuntimeError( + "The WORLD_SIZE env variable is not set in distributed environment" + ) + group = list(range(world_size)) + group = trt.PluginField( + "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32 + ) + + p_dtype = trt.float16 + pf_dtype = trt.PluginField( + "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32 + ) + pfc = [group, pf_dtype] + p_strategy = trt.PluginField( + "strategy", np.array([int(strategy)], np.int8), trt.PluginFieldType.INT8 + ) + pfc.append(p_strategy) + p_config = trt.PluginField( + "config", np.array([int(config)], np.int8), trt.PluginFieldType.INT8 + ) + pfc.append(p_config) + p_counter = trt.PluginField( + "counter", np.array([counter], np.int32), trt.PluginFieldType.INT32 + ) + pfc.append(p_counter) + + pfc = trt.PluginFieldCollection(pfc) + ar_plug = allreduce_plg_creator.create_plugin("allreduce", pfc) + + layer = ctx.net.add_plugin_v2(plug_inputs, ar_plug) + set_layer_name(layer, target, name, source_ir) + return layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index 0d55c5f014..dc5458d2c8 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -2,7 +2,6 @@ from typing import Optional, Sequence, Union import numpy as np -import tensorrt as trt import torch from torch.fx.node import Target from torch_tensorrt.dynamo._SourceIR import SourceIR @@ -16,7 +15,6 @@ to_numpy, ) from torch_tensorrt.dynamo.conversion.impl.elementwise import convert_binary_elementwise -from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape from torch_tensorrt.dynamo.utils import DYNAMIC_DIM from torch_tensorrt.fx.converters.converter_utils import ( @@ -25,6 +23,8 @@ ) from torch_tensorrt.fx.types import TRTTensor +import tensorrt as trt + _LOGGER: logging.Logger = logging.getLogger(__name__) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/__init__.py b/py/torch_tensorrt/dynamo/lowering/passes/__init__.py index c0e2803e60..716c6505fe 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/__init__.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/__init__.py @@ -1,3 +1,4 @@ from ._aten_lowering_pass import * +from ._modify_reshape_complex_nodes import modify_reshape_complex_nodes from .remove_sym_nodes import remove_sym_nodes from .repair_input_aliasing import repair_input_aliasing diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index 661c76d3b6..f24fd8ec21 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -6,6 +6,7 @@ from .accumulate_fp32_matmul import accumulate_fp32_matmul from .constant_folding import constant_fold +from .fuse_distributed_ops import fuse_distributed_ops from .fuse_prims_broadcast import fuse_prims_broadcast from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention from .pass_manager import DynamoPassManager @@ -22,6 +23,7 @@ constant_fold, repair_input_as_output, fuse_prims_broadcast, + fuse_distributed_ops, replace_max_pool_with_indices, lower_scaled_dot_product_attention, view_to_reshape, diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_modify_reshape_complex_nodes.py b/py/torch_tensorrt/dynamo/lowering/passes/_modify_reshape_complex_nodes.py new file mode 100644 index 0000000000..f8ca1f71b9 --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/_modify_reshape_complex_nodes.py @@ -0,0 +1,105 @@ +import logging + +import torch + +logger = logging.getLogger(__name__) + +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, + find_complex_nodes, +) + +from ._replace_complex_placeholder_to_tuple import replace_complex_placeholder_to_tuple + + +def tensorrt_complex_mul(args0, args1): + args0_real, args0_imag = torch.ops.aten.split.Tensor(args0, 1, -1) + args1_real, args1_imag = torch.ops.aten.split.Tensor(args1, 1, -1) + + args0_real = torch.ops.aten.squeeze.dim(args0_real, -1) + args0_imag = torch.ops.aten.squeeze.dim(args0_imag, -1) + args1_real = torch.ops.aten.squeeze.dim(args1_real, -1) + args1_imag = torch.ops.aten.squeeze.dim(args1_imag, -1) + + complex_mul_real = torch.ops.aten.sub( + torch.ops.aten.mul(args0_real, args1_real), + torch.ops.aten.mul(args0_imag, args1_imag), + ) + complex_mul_imag = torch.ops.aten.add( + torch.ops.aten.mul(args0_real, args1_imag), + torch.ops.aten.mul(args0_imag, args1_real), + ) + + return torch.ops.aten.stack((complex_mul_real, complex_mul_imag), -1) + + +def remove_complex_real_view_nodes(gm: torch.fx.GraphModule): + modified_graph = False + nodes_to_remove = [] + for node in gm.graph.nodes: + if "view_as_complex" in node.name or "view_as_real" in node.name: + nodes_to_remove.append(node) + + for node in nodes_to_remove: + input_node = node.args[0] if node.args else None + + for other_node in gm.graph.nodes: + new_args = tuple( + input_node if arg is node else arg for arg in other_node.args + ) + other_node.args = new_args + + gm.graph.erase_node(node) + modified_graph = True + + if modified_graph: + gm = clean_up_graph_after_modifications(gm) + logger.debug( + f"Graph after removing view_as_complex nodes and view_as_real nodes:\n{gm.graph}" + ) + + +def modify_reshape_nodes(gm: torch.fx.GraphModule, complex_nodes): + for node in gm.graph.nodes: + if node in complex_nodes: + # slice and transpose will remain same + if "reshape" in node.name: + new_shape = list(node.args[1]) + [2] + node.args = (node.args[0], tuple(new_shape)) + + +def modify_mul_nodes(gm: torch.fx.GraphModule, complex_nodes): + modified_graph = False + for node in gm.graph.nodes: + if node in complex_nodes: + if "mul" in node.name: + complex_mul_args = (node.args[0], node.args[1]) + with gm.graph.inserting_after(node): + replacement_node = gm.graph.create_node( + op="call_function", + target=tensorrt_complex_mul, + args=complex_mul_args, + ) + node.replace_all_uses_with(replacement_node) + replacement_node.meta.update(node.meta) + modified_graph = True + gm.graph.erase_node(node) + + if modified_graph: + gm = clean_up_graph_after_modifications(gm) + logger.debug( + f"Graph after custom complex mul nodes is applied to the graph:\n{gm.graph}" + ) + + +def modify_complex_nodes(gm: torch.fx.GraphModule, complex_nodes): + modify_reshape_nodes(gm, complex_nodes) + remove_complex_real_view_nodes(gm) + modify_mul_nodes(gm, complex_nodes) + + +def modify_reshape_complex_nodes(gm: torch.fx.GraphModule, complexInputIndices): + complex_nodes = find_complex_nodes(gm) + if complex_nodes: + replace_complex_placeholder_to_tuple(gm, complexInputIndices) + modify_complex_nodes(gm, complex_nodes) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_replace_complex_placeholder_to_tuple.py b/py/torch_tensorrt/dynamo/lowering/passes/_replace_complex_placeholder_to_tuple.py new file mode 100644 index 0000000000..e2edec3d28 --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/_replace_complex_placeholder_to_tuple.py @@ -0,0 +1,112 @@ +import logging +from typing import List, Tuple + +import torch +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.node import _get_qualified_name +from torch_tensorrt.dynamo.conversion.converter_utils import args_bounds_check + +# dead-code elimination, linting, and recompilation for graph, in-place +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) + +logger = logging.getLogger(__name__) + + +def replace_complex_placeholder_to_tuple( + gm: torch.fx.GraphModule, + inputListindices: List[int], +) -> torch.fx.GraphModule: + modified_graph = False + input_arg_list = [f"arg{inputListIndex}_1" for inputListIndex in inputListindices] + for node in gm.graph.nodes: + if node.op == "placeholder" and node.target in input_arg_list: + from torch._subclasses.fake_tensor import FakeTensorMode + + node_shape = node.meta["val"].size() + new_node_shape = node_shape + (2,) + new_node_dtype = None + if node.meta["val"].dtype == torch.complex64: + new_node_dtype = torch.float32 + else: + new_node_dtype = torch.float64 + fake_mode = FakeTensorMode() + + real_tensor = torch.empty(new_node_shape, dtype=new_node_dtype) + with FakeTensorMode() as fake_mode: + new_placeholder_tuple = fake_mode.from_tensor(real_tensor) + node.meta["val"] = new_placeholder_tuple + modified_graph = True + # propagate the meta data change for the downstream ops + # TODO:to check if this is required in all cases + propogate_complex_num_shape_change_till_complex_mul(gm, node, fake_mode) + + # If graph was modified, clean it up + if modified_graph: + gm = clean_up_graph_after_modifications(gm) + logger.debug( + f"Graph after fusing wait_tensor and distributed op tensor:\n{gm.graph}" + ) + + return gm + + +def infer_slice_shape(node: torch.fx.Node) -> Tuple[int, ...]: + input_shape = node.args[0].meta["val"].shape + slice_args = node.args + dim = slice_args[1] + start = slice_args[2] + end = slice_args[3] + step = args_bounds_check(slice_args, 4, replacement=1) + new_shape = list(input_shape) + new_shape[dim] = (end - start + step - 1) // step + return tuple(new_shape) + + +def infer_reshape_shape(node: torch.fx.Node) -> torch.fx.node.Argument: + return node.args[1] + + +shape_inference_funcs = { + "torch.ops.aten.slice.Tensor": infer_slice_shape, + "torch.ops.aten.reshape.default": infer_reshape_shape, +} + + +# Please note this function is for the use case of Llama model +# with complex placeholder->reshape->slice->complex mul +# Hence mul is the terminating op +def propogate_complex_num_shape_change_till_complex_mul( + node: torch.fx.Node, start_node: torch.fx.Node, fake_mode: FakeTensorMode +) -> None: + visited_nodes = set() + stack = [start_node] + while stack: + node = stack.pop() + if node in visited_nodes: + continue + visited_nodes.add(node) + update_node_meta(node, fake_mode) + for user in node.users: + if ( + user.op == "call_function" + and _get_qualified_name(user.target) == "torch.ops.aten.mul.Tensor" + ): + continue + stack.append(user) + + +def update_node_meta(node: torch.fx.Node, fake_mode: FakeTensorMode) -> None: + op_name = node.name + op_target = node.target + + if node.op == "call_function": + op_target = _get_qualified_name(node.target) + + if op_target in shape_inference_funcs: + new_shape = shape_inference_funcs[op_target](node) + real_tensor = torch.empty(new_shape, dtype=node.meta["val"].dtype) + node.meta["val"] = fake_mode.from_tensor(real_tensor) + else: + print("No shape for the inference function", {op_name}) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py b/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py new file mode 100644 index 0000000000..f709f177d6 --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py @@ -0,0 +1,80 @@ +import logging +from typing import Any + +import torch +from torch_tensorrt.dynamo._settings import CompilationSettings + +# dead-code elimination, linting, and recompilation for graph, in-place +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) + +logger = logging.getLogger(__name__) + + +# TODO: @apbose make these actual torch custom ops, should allow aot_export +def tensorrt_fused_nccl_all_gather_op( + inp: Any, group_size: int, group_name: str +) -> torch.Tensor: + return torch.ops._c10d_functional.wait_tensor.default( + torch.ops._c10d_functional.all_gather_into_tensor.default( + inp, group_size, group_name + ) + ) + + +def tensorrt_fused_nccl_reduce_scatter_op( + inp: Any, reduce_op: str, group_size: int, group_name: str +) -> torch.Tensor: + return torch.ops._c10d_functional.wait_tensor.default( + torch.ops._c10d_functional.reduce_scatter_tensor.default( + inp, reduce_op, group_size, group_name + ) + ) + + +def fuse_distributed_ops( + gm: torch.fx.GraphModule, settings: CompilationSettings +) -> torch.fx.GraphModule: + modified_graph = False + for node in gm.graph.nodes: + if ( + node.target + in ( + torch.ops._c10d_functional.all_gather_into_tensor.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, + ) + and len(node.users) == 1 + and list(node.users)[0].target + == torch.ops._c10d_functional.wait_tensor.default + ): + wait_tensor_node = list(node.users)[0] + fused_op = None + if node.target == torch.ops._c10d_functional.all_gather_into_tensor.default: + with gm.graph.inserting_after(wait_tensor_node): + fused_node = gm.graph.create_node( + op="call_function", + target=tensorrt_fused_nccl_all_gather_op, # Define your custom fused function + args=(node.args[0], node.args[1], node.args[2]), + ) + else: + fused_node = gm.graph.create_node( + op="call_function", + target=tensorrt_fused_nccl_reduce_scatter_op, # Define your custom fused function + args=(node.args[0], node.args[1], node.args[2], node.args[3]), + ) + + wait_tensor_node.replace_all_uses_with(fused_node) + fused_node.meta.update(node.meta) + modified_graph = True + gm.graph.erase_node(wait_tensor_node) + gm.graph.erase_node(node) + + # If graph was modified, clean it up + if modified_graph: + gm = clean_up_graph_after_modifications(gm) + logger.debug( + f"Graph after fusing wait_tensor and distributed op tensor:\n{gm.graph}" + ) + + return gm diff --git a/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py b/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py index 31a55099c2..1736a234a2 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/pass_utils.py @@ -29,3 +29,42 @@ def get_tensor_placeholders( ] return placeholders + + +def find_complex_nodes(gm: torch.fx.GraphModule): + complex_nodes = [] + complexNodes = {} + for node in gm.graph.nodes: + if is_node_complex(node, complexNodes): + complex_nodes.append(node) + return complex_nodes + + +def is_node_complex(node: torch.fx.Node, complexNodes): + if not isinstance(node, torch.fx.Node): + return False + if node.name in complexNodes: + return True + if node.op == "call_function" and node.args is not None: + for arg in node.args: + if isinstance(arg, int): + continue + elif isinstance(arg, (list, tuple)): + for eachNode in arg: + if is_node_complex(eachNode, complexNodes): + complexNodes[node.name] = True + return True + + elif hasattr(arg, "meta") and "val" in arg.meta: + if isinstance(arg.meta["val"], (list, tuple)): + for eachFakeTensorMeta in arg.meta["val"]: + if eachFakeTensorMeta.dtype in ( + torch.complex64, + torch.complex128, + ): + complexNodes[node.name] = True + return True + elif arg.meta["val"].dtype in (torch.complex64, torch.complex128): + complexNodes[node.name] = True + return True + return False diff --git a/pyproject.toml b/pyproject.toml index e2e451b430..0472b39f12 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,9 +85,10 @@ dev = [ torchvision = [ "torchvision", ] #Leaving torchvisions dependency unconstrained so uv can just install something that should work for the torch we have. TV's on PyT makes it hard to put version constrains in -quantization = ["nvidia-modelopt[deploy,hf,torch]~=0.17.0"] +quantization = ["nvidia-modelopt[deploy,hf,torch]>=0.17.0"] monitoring-tools = ["rich>=13.7.1"] jupyter = ["rich[jupyter]>=13.7.1"] +distributed = ["tensorrt-llm>=0.16.0"] [project.urls] Homepage = "https://pytorch.org/tensorrt" diff --git a/uv.lock b/uv.lock index 56c34c4c14..a8982d6b88 100644 --- a/uv.lock +++ b/uv.lock @@ -40,6 +40,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643 }, ] +[[package]] +name = "anyio" +version = "4.7.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "exceptiongroup", marker = "(python_full_version < '3.11' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform == 'windows')" }, + { name = "idna", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "sniffio", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "typing-extensions", marker = "(python_full_version < '3.13' and sys_platform == 'linux') or (python_full_version < '3.13' and sys_platform == 'windows')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f6/40/318e58f669b1a9e00f5c4453910682e2d9dd594334539c7b7817dabb765f/anyio-4.7.0.tar.gz", hash = "sha256:2f834749c602966b7d456a7567cafcb309f96482b5081d14ac93ccd457f9dd48", size = 177076 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/7a/4daaf3b6c08ad7ceffea4634ec206faeff697526421c20f07628c7372156/anyio-4.7.0-py3-none-any.whl", hash = "sha256:ea60c3723ab42ba6fff7e8ccb0488c898ec538ff4df1f1d5e642c3601d07e352", size = 93052 }, +] + [[package]] name = "asttokens" version = "2.4.1" @@ -52,6 +67,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/45/86/4736ac618d82a20d87d2f92ae19441ebc7ac9e7a581d7e58bbe79233b24a/asttokens-2.4.1-py2.py3-none-any.whl", hash = "sha256:051ed49c3dcae8913ea7cd08e46a606dba30b79993209636c4875bc1d637bc24", size = 27764 }, ] +[[package]] +name = "async-timeout" +version = "5.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a5/ae/136395dfbfe00dfc94da3f3e136d0b13f394cba8f4841120e34226265780/async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3", size = 9274 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fe/ba/e2081de779ca30d473f21f5b30e0e737c438205440784c7dfc81efc2b029/async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c", size = 6233 }, +] + +[[package]] +name = "attrs" +version = "24.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/48/c8/6260f8ccc11f0917360fc0da435c5c9c7504e3db174d5a12a1494887b045/attrs-24.3.0.tar.gz", hash = "sha256:8f5c07333d543103541ba7be0e2ce16eeee8130cb0b3f9238ab904ce1e85baff", size = 805984 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/89/aa/ab0f7891a01eeb2d2e338ae8fecbe57fcebea1a24dbb64d45801bfab481d/attrs-24.3.0-py3-none-any.whl", hash = "sha256:ac96cd038792094f438ad1f6ff80837353805ac950cd2aa0e0625ef19850c308", size = 63397 }, +] + [[package]] name = "black" version = "24.8.0" @@ -74,6 +107,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/27/1e/83fa8a787180e1632c3d831f7e58994d7aaf23a0961320d21e84f922f919/black-24.8.0-py3-none-any.whl", hash = "sha256:972085c618ee94f402da1af548a4f218c754ea7e5dc70acb168bfaca4c2542ed", size = 206504 }, ] +[[package]] +name = "build" +version = "1.2.2.post1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "(os_name == 'nt' and sys_platform == 'linux') or (os_name == 'nt' and sys_platform == 'windows')" }, + { name = "importlib-metadata", marker = "(python_full_version < '3.10.2' and sys_platform == 'linux') or (python_full_version < '3.10.2' and sys_platform == 'windows')" }, + { name = "packaging", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "pyproject-hooks", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "tomli", marker = "(python_full_version < '3.11' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform == 'windows')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7d/46/aeab111f8e06793e4f0e421fcad593d547fb8313b50990f31681ee2fb1ad/build-1.2.2.post1.tar.gz", hash = "sha256:b36993e92ca9375a219c99e606a122ff365a760a2d4bba0caa09bd5278b608b7", size = 46701 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/84/c2/80633736cd183ee4a62107413def345f7e6e3c01563dbca1417363cf957e/build-1.2.2.post1-py3-none-any.whl", hash = "sha256:1d61c0887fa860c01971625baae8bdd338e517b836a2f70dd1f7aa3a6b2fc5b5", size = 22950 }, +] + [[package]] name = "certifi" version = "2024.8.30" @@ -82,6 +131,55 @@ wheels = [ { url = "https://download.pytorch.org/whl/test/certifi-2024.8.30-py3-none-any.whl" }, ] +[[package]] +name = "cffi" +version = "1.17.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pycparser", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fc/97/c783634659c2920c3fc70419e3af40972dbaf758daa229a7d6ea6135c90d/cffi-1.17.1.tar.gz", hash = "sha256:1c39c6016c32bc48dd54561950ebd6836e1670f2ae46128f67cf49e789c52824", size = 516621 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/de/cc/4635c320081c78d6ffc2cab0a76025b691a91204f4aa317d568ff9280a2d/cffi-1.17.1-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:edae79245293e15384b51f88b00613ba9f7198016a5948b5dddf4917d4d26382", size = 426024 }, + { url = "https://files.pythonhosted.org/packages/b6/7b/3b2b250f3aab91abe5f8a51ada1b717935fdaec53f790ad4100fe2ec64d1/cffi-1.17.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45398b671ac6d70e67da8e4224a065cec6a93541bb7aebe1b198a61b58c7b702", size = 448188 }, + { url = "https://files.pythonhosted.org/packages/d3/48/1b9283ebbf0ec065148d8de05d647a986c5f22586b18120020452fff8f5d/cffi-1.17.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ad9413ccdeda48c5afdae7e4fa2192157e991ff761e7ab8fdd8926f40b160cc3", size = 455571 }, + { url = "https://files.pythonhosted.org/packages/40/87/3b8452525437b40f39ca7ff70276679772ee7e8b394934ff60e63b7b090c/cffi-1.17.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5da5719280082ac6bd9aa7becb3938dc9f9cbd57fac7d2871717b1feb0902ab6", size = 436687 }, + { url = "https://files.pythonhosted.org/packages/8d/fb/4da72871d177d63649ac449aec2e8a29efe0274035880c7af59101ca2232/cffi-1.17.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2bb1a08b8008b281856e5971307cc386a8e9c5b625ac297e853d36da6efe9c17", size = 446211 }, + { url = "https://files.pythonhosted.org/packages/ab/a0/62f00bcb411332106c02b663b26f3545a9ef136f80d5df746c05878f8c4b/cffi-1.17.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:045d61c734659cc045141be4bae381a41d89b741f795af1dd018bfb532fd0df8", size = 461325 }, + { url = "https://files.pythonhosted.org/packages/36/83/76127035ed2e7e27b0787604d99da630ac3123bfb02d8e80c633f218a11d/cffi-1.17.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:6883e737d7d9e4899a8a695e00ec36bd4e5e4f18fabe0aca0efe0a4b44cdb13e", size = 438784 }, + { url = "https://files.pythonhosted.org/packages/21/81/a6cd025db2f08ac88b901b745c163d884641909641f9b826e8cb87645942/cffi-1.17.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6b8b4a92e1c65048ff98cfe1f735ef8f1ceb72e3d5f0c25fdb12087a23da22be", size = 461564 }, + { url = "https://files.pythonhosted.org/packages/94/dd/a3f0118e688d1b1a57553da23b16bdade96d2f9bcda4d32e7d2838047ff7/cffi-1.17.1-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f75c7ab1f9e4aca5414ed4d8e5c0e303a34f4421f8a0d47a4d019ceff0ab6af4", size = 445259 }, + { url = "https://files.pythonhosted.org/packages/2e/ea/70ce63780f096e16ce8588efe039d3c4f91deb1dc01e9c73a287939c79a6/cffi-1.17.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1ed2dd2972641495a3ec98445e09766f077aee98a1c896dcb4ad0d303628e41", size = 469200 }, + { url = "https://files.pythonhosted.org/packages/1c/a0/a4fa9f4f781bda074c3ddd57a572b060fa0df7655d2a4247bbe277200146/cffi-1.17.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:46bf43160c1a35f7ec506d254e5c890f3c03648a4dbac12d624e4490a7046cd1", size = 477235 }, + { url = "https://files.pythonhosted.org/packages/62/12/ce8710b5b8affbcdd5c6e367217c242524ad17a02fe5beec3ee339f69f85/cffi-1.17.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a24ed04c8ffd54b0729c07cee15a81d964e6fee0e3d4d342a27b020d22959dc6", size = 459721 }, + { url = "https://files.pythonhosted.org/packages/ff/6b/d45873c5e0242196f042d555526f92aa9e0c32355a1be1ff8c27f077fd37/cffi-1.17.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:610faea79c43e44c71e1ec53a554553fa22321b65fae24889706c0a84d4ad86d", size = 467242 }, + { url = "https://files.pythonhosted.org/packages/1a/52/d9a0e523a572fbccf2955f5abe883cfa8bcc570d7faeee06336fbd50c9fc/cffi-1.17.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:a9b15d491f3ad5d692e11f6b71f7857e7835eb677955c00cc0aefcd0669adaf6", size = 477999 }, + { url = "https://files.pythonhosted.org/packages/44/74/f2a2460684a1a2d00ca799ad880d54652841a780c4c97b87754f660c7603/cffi-1.17.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:de2ea4b5833625383e464549fec1bc395c1bdeeb5f25c4a3a82b5a8c756ec22f", size = 454242 }, + { url = "https://files.pythonhosted.org/packages/f8/4a/34599cac7dfcd888ff54e801afe06a19c17787dfd94495ab0c8d35fe99fb/cffi-1.17.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:fc48c783f9c87e60831201f2cce7f3b2e4846bf4d8728eabe54d60700b318a0b", size = 478604 }, + { url = "https://files.pythonhosted.org/packages/cc/b6/db007700f67d151abadf508cbfd6a1884f57eab90b1bb985c4c8c02b0f28/cffi-1.17.1-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1257bdabf294dceb59f5e70c64a3e2f462c30c7ad68092d01bbbfb1c16b1ba36", size = 454803 }, + { url = "https://files.pythonhosted.org/packages/1a/df/f8d151540d8c200eb1c6fba8cd0dfd40904f1b0682ea705c36e6c2e97ab3/cffi-1.17.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da95af8214998d77a98cc14e3a3bd00aa191526343078b530ceb0bd710fb48a5", size = 478850 }, + { url = "https://files.pythonhosted.org/packages/28/c0/b31116332a547fd2677ae5b78a2ef662dfc8023d67f41b2a83f7c2aa78b1/cffi-1.17.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d63afe322132c194cf832bfec0dc69a99fb9bb6bbd550f161a49e9e855cc78ff", size = 485729 }, + { url = "https://files.pythonhosted.org/packages/91/2b/9a1ddfa5c7f13cab007a2c9cc295b70fbbda7cb10a286aa6810338e60ea1/cffi-1.17.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f79fc4fc25f1c8698ff97788206bb3c2598949bfe0fef03d299eb1b5356ada99", size = 471256 }, + { url = "https://files.pythonhosted.org/packages/b2/d5/da47df7004cb17e4955df6a43d14b3b4ae77737dff8bf7f8f333196717bf/cffi-1.17.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b62ce867176a75d03a665bad002af8e6d54644fad99a3c70905c543130e39d93", size = 479424 }, + { url = "https://files.pythonhosted.org/packages/0b/ac/2a28bcf513e93a219c8a4e8e125534f4f6db03e3179ba1c45e949b76212c/cffi-1.17.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:386c8bf53c502fff58903061338ce4f4950cbdcb23e2902d86c0f722b786bbe3", size = 484568 }, + { url = "https://files.pythonhosted.org/packages/d4/38/ca8a4f639065f14ae0f1d9751e70447a261f1a30fa7547a828ae08142465/cffi-1.17.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4ceb10419a9adf4460ea14cfd6bc43d08701f0835e979bf821052f1805850fe8", size = 488736 }, + { url = "https://files.pythonhosted.org/packages/0e/2d/eab2e858a91fdff70533cab61dcff4a1f55ec60425832ddfdc9cd36bc8af/cffi-1.17.1-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d01b12eeeb4427d3110de311e1774046ad344f5b1a7403101878976ecd7a10f3", size = 454792 }, + { url = "https://files.pythonhosted.org/packages/75/b2/fbaec7c4455c604e29388d55599b99ebcc250a60050610fadde58932b7ee/cffi-1.17.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:706510fe141c86a69c8ddc029c7910003a17353970cff3b904ff0686a5927683", size = 478893 }, + { url = "https://files.pythonhosted.org/packages/4f/b7/6e4a2162178bf1935c336d4da8a9352cccab4d3a5d7914065490f08c0690/cffi-1.17.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de55b766c7aa2e2a3092c51e0483d700341182f08e67c63630d5b6f200bb28e5", size = 485810 }, + { url = "https://files.pythonhosted.org/packages/c7/8a/1d0e4a9c26e54746dc08c2c6c037889124d4f59dffd853a659fa545f1b40/cffi-1.17.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c59d6e989d07460165cc5ad3c61f9fd8f1b4796eacbd81cee78957842b834af4", size = 471200 }, + { url = "https://files.pythonhosted.org/packages/26/9f/1aab65a6c0db35f43c4d1b4f580e8df53914310afc10ae0397d29d697af4/cffi-1.17.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd398dbc6773384a17fe0d3e7eeb8d1a21c2200473ee6806bb5e6a8e62bb73dd", size = 479447 }, + { url = "https://files.pythonhosted.org/packages/5f/e4/fb8b3dd8dc0e98edf1135ff067ae070bb32ef9d509d6cb0f538cd6f7483f/cffi-1.17.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:3edc8d958eb099c634dace3c7e16560ae474aa3803a5df240542b305d14e14ed", size = 484358 }, + { url = "https://files.pythonhosted.org/packages/f1/47/d7145bf2dc04684935d57d67dff9d6d795b2ba2796806bb109864be3a151/cffi-1.17.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:72e72408cad3d5419375fc87d289076ee319835bdfa2caad331e377589aebba9", size = 488469 }, + { url = "https://files.pythonhosted.org/packages/ed/65/25a8dc32c53bf5b7b6c2686b42ae2ad58743f7ff644844af7cdb29b49361/cffi-1.17.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1d599671f396c4723d016dbddb72fe8e0397082b0a77a4fab8028923bec050e8", size = 424910 }, + { url = "https://files.pythonhosted.org/packages/42/7a/9d086fab7c66bd7c4d0f27c57a1b6b068ced810afc498cc8c49e0088661c/cffi-1.17.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca74b8dbe6e8e8263c0ffd60277de77dcee6c837a3d0881d8c1ead7268c9e576", size = 447200 }, + { url = "https://files.pythonhosted.org/packages/da/63/1785ced118ce92a993b0ec9e0d0ac8dc3e5dbfbcaa81135be56c69cabbb6/cffi-1.17.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f7f5baafcc48261359e14bcd6d9bff6d4b28d9103847c9e136694cb0501aef87", size = 454565 }, + { url = "https://files.pythonhosted.org/packages/74/06/90b8a44abf3556599cdec107f7290277ae8901a58f75e6fe8f970cd72418/cffi-1.17.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:98e3969bcff97cae1b2def8ba499ea3d6f31ddfdb7635374834cf89a1a08ecf0", size = 435635 }, + { url = "https://files.pythonhosted.org/packages/bd/62/a1f468e5708a70b1d86ead5bab5520861d9c7eacce4a885ded9faa7729c3/cffi-1.17.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cdf5ce3acdfd1661132f2a9c19cac174758dc2352bfe37d98aa7512c6b7178b3", size = 445218 }, + { url = "https://files.pythonhosted.org/packages/5b/95/b34462f3ccb09c2594aa782d90a90b045de4ff1f70148ee79c69d37a0a5a/cffi-1.17.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9755e4345d1ec879e3849e62222a18c7174d65a6a92d5b346b1863912168b595", size = 460486 }, + { url = "https://files.pythonhosted.org/packages/fc/fc/a1e4bebd8d680febd29cf6c8a40067182b64f00c7d105f8f26b5bc54317b/cffi-1.17.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f1e22e8c4419538cb197e4dd60acc919d7696e5ef98ee4da4e01d3f8cfa4cc5a", size = 437911 }, + { url = "https://files.pythonhosted.org/packages/e6/c3/21cab7a6154b6a5ea330ae80de386e7665254835b9e98ecc1340b3a7de9a/cffi-1.17.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c03e868a0b3bc35839ba98e74211ed2b05d2119be4e8a0f224fba9384f1fe02e", size = 460632 }, +] + [[package]] name = "cfgv" version = "3.4.0" @@ -134,14 +232,26 @@ wheels = [ [[package]] name = "click" -version = "8.1.7" +version = "8.1.8" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "colorama", marker = "(platform_system == 'Windows' and sys_platform == 'linux') or (platform_system == 'Windows' and sys_platform == 'windows')" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 } +sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 } wheels = [ - { url = "https://files.pythonhosted.org/packages/00/2e/d53fa4befbf2cfa713304affc7ca780ce4fc1fd8710527771b58311a3229/click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28", size = 97941 }, + { url = "https://files.pythonhosted.org/packages/7e/d4/7ebdbd03970677812aac39c869717059dbb71a4cfc033ca6e5221787892c/click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2", size = 98188 }, +] + +[[package]] +name = "click-option-group" +version = "0.5.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e7/b8/91054601a2e05fd9060cb1baf56be5b24145817b059e078669e1099529c7/click-option-group-0.5.6.tar.gz", hash = "sha256:97d06703873518cc5038509443742b25069a3c7562d1ea72ff08bfadde1ce777", size = 16517 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/af/75/81ea958bc0f7e410257cb2a42531b93a7695a31930cde87192c010a52c50/click_option_group-0.5.6-py3-none-any.whl", hash = "sha256:38a26d963ee3ad93332ddf782f9259c5bdfe405e73408d943ef5e7d0c3767ec7", size = 12467 }, ] [[package]] @@ -161,6 +271,27 @@ wheels = [ { url = "https://download.pytorch.org/whl/test/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6" }, ] +[[package]] +name = "colored" +version = "2.2.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2f/98/4d4546307039955eec131cf9538732fb7a28d2db2090cd1d4e4a135829e1/colored-2.2.4.tar.gz", hash = "sha256:595e1dd7f3b472ea5f12af21d2fec8a2ea2cf8f9d93e67180197330b26df9b61", size = 13202 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/75/d1/548f697f88872321525e294f8863efbdd1c313964b7f94e49ab0dc4f2895/colored-2.2.4-py3-none-any.whl", hash = "sha256:a7069673bd90a35f46cb748d012c17284a0668d2f1c06bc7a51822a2d5ad2112", size = 16109 }, +] + +[[package]] +name = "coloredlogs" +version = "15.0.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "humanfriendly", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/cc/c7/eed8f27100517e8c0e6b923d5f0845d0cb99763da6fdee00478f91db7325/coloredlogs-15.0.1.tar.gz", hash = "sha256:7c991aa71a4577af2f82600d8f8f3a89f936baeaf9b50a9c197da014e5bf16b0", size = 278520 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a7/06/3d6badcf13db419e25b07041d9c7b4a2c331d3f4e7134445ec5df57714cd/coloredlogs-15.0.1-py2.py3-none-any.whl", hash = "sha256:612ee75c546f53e92e70049c9dbfcc18c935a2b9a53b66085ce9ef6a6e5c0934", size = 46018 }, +] + [[package]] name = "comm" version = "0.2.2" @@ -173,6 +304,45 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e6/75/49e5bfe642f71f272236b5b2d2691cf915a7283cc0ceda56357b61daa538/comm-0.2.2-py3-none-any.whl", hash = "sha256:e6fb86cb70ff661ee8c9c14e7d36d6de3b4066f1441be4063df9c5009f0a64d3", size = 7180 }, ] +[[package]] +name = "cuda-python" +version = "12.6.2.post1" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5d/b4/5905cc801fd02a150204c97e18bbdabe4d167b84f7453d241c080128a576/cuda_python-12.6.2.post1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:be1f268aae08d509e4af2d8b8465b74351de39d8f439a5a98caf9b276e027d9b", size = 24497758 }, + { url = "https://files.pythonhosted.org/packages/7d/8e/74c493c3bd3855509bdca3001e84d0b4af499cbc83a71f3b4d42ba827781/cuda_python-12.6.2.post1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ee91c92f34fc140b8e241a2681747cfb4442fa3a9dc817376d3090f6c73a0c0f", size = 24910886 }, + { url = "https://files.pythonhosted.org/packages/ec/a1/9a00f7ace9ca34c8b54a8a21c5406a602daadebdb93fd0bddf886c0ec8e8/cuda_python-12.6.2.post1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:20271f7979495435e2192758ca52a227dc70e04af1453442a44d3149b94c4630", size = 25349718 }, + { url = "https://files.pythonhosted.org/packages/4e/f4/9badcb1f59263365a2ce49b09b4a0bfa95c1c102a367c5601bcfe118cd96/cuda_python-12.6.2.post1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d47ba3464fd890025a6d1929dac5e71f2d8c16d0abfe461123cf8be1f975d3a0", size = 25742930 }, + { url = "https://files.pythonhosted.org/packages/d5/c4/4cc082de59d4a465cd29125853b916d5b229a6961cc173a0f22ee699621b/cuda_python-12.6.2.post1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ea1e5b26cdfa424c26f098f6ec17340608b73dddc5fd2059f5b31897eabd9916", size = 24612106 }, + { url = "https://files.pythonhosted.org/packages/95/70/7fc0e460a4d2f24b79854cbe7cd474a5457f4b849dd9cd794fafeb4fb85a/cuda_python-12.6.2.post1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ccda3b9400dd6568b3db118aae1c783ef26f2457d6635197999fa42b46b5187", size = 24750651 }, + { url = "https://files.pythonhosted.org/packages/4f/6e/0583be6c547b81e8a3790b8ea845776c2655641b5cae247aacb0bf5ebf51/cuda_python-12.6.2.post1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3ea25bf01b08563b2d9ff207db1f6c9beb21cd79528fc3f818cf1771c6a05306", size = 24473514 }, + { url = "https://files.pythonhosted.org/packages/ab/ab/51c55ca37cc4622e6bf6953a0fb90ea1f871ab5e090fc01082a2bb402443/cuda_python-12.6.2.post1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6dfb5bdafdc26b47cd345861e6374843bede99fe81804377404c35c19e59e21d", size = 24912694 }, +] + +[[package]] +name = "datasets" +version = "2.14.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohttp", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "dill", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "fsspec", extra = ["http"], marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "huggingface-hub", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "multiprocess", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "numpy", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "packaging", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "pandas", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "pyarrow", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "pyyaml", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "requests", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "tqdm", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "xxhash", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1d/69/8cc725b5d38968fd118e4ce56a483b16e75b7793854c1a392ec4a34eeb31/datasets-2.14.4.tar.gz", hash = "sha256:ef29c2b5841de488cd343cfc26ab979bff77efa4d2285af51f1ad7db5c46a83b", size = 2178719 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/66/f8/38298237d18d4b6a8ee5dfe390e97bed5adb8e01ec6f9680c0ddf3066728/datasets-2.14.4-py3-none-any.whl", hash = "sha256:29336bd316a7d827ccd4da2236596279b20ca2ac78f64c04c9483da7cbc2459b", size = 519335 }, +] + [[package]] name = "decorator" version = "5.1.1" @@ -210,6 +380,37 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8e/41/9307e4f5f9976bc8b7fea0b66367734e8faf3ec84bc0d412d8cfabbb66cd/distlib-0.3.8-py2.py3-none-any.whl", hash = "sha256:034db59a0b96f8ca18035f36290806a9a6e6bd9d1ff91e45a7f172eb17e51784", size = 468850 }, ] +[[package]] +name = "distro" +version = "1.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/fc/f8/98eea607f65de6527f8a2e8885fc8015d3e6f5775df186e443e0964a11c3/distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed", size = 60722 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277 }, +] + +[[package]] +name = "evaluate" +version = "0.4.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "datasets", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "dill", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "fsspec", extra = ["http"], marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "huggingface-hub", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "multiprocess", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "numpy", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "packaging", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "pandas", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "requests", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "tqdm", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "xxhash", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5a/a0/10a56e0939ece94c54276e81459cb4101f46f0e9a6f54fc31a35f64e8854/evaluate-0.4.3.tar.gz", hash = "sha256:3a5700cf83aabee9549264e1e5666f116367c61dbd4d38352015e859a5e2098d", size = 65679 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a2/e7/cbca9e2d2590eb9b5aa8f7ebabe1beb1498f9462d2ecede5c9fd9735faaf/evaluate-0.4.3-py3-none-any.whl", hash = "sha256:47d8770bdea76e2c2ed0d40189273027d1a41ccea861bcc7ba12d30ec5d1e517", size = 84010 }, +] + [[package]] name = "exceptiongroup" version = "1.2.2" @@ -246,6 +447,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c7/39/689391845f5dc48df81b0c22248d5f66919b82da12f2bab1424bc3610529/expecttest-0.1.6-py3-none-any.whl", hash = "sha256:7cf2db203c06f9e3173670ca9d09ac00912e535139afac2c7458c1627b1a3ee6", size = 6535 }, ] +[[package]] +name = "fastapi" +version = "0.115.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "starlette", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "typing-extensions", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a9/db/5781f19bd30745885e0737ff3fdd4e63e7bc691710f9da691128bb0dc73b/fastapi-0.115.4.tar.gz", hash = "sha256:db653475586b091cb8b2fec2ac54a680ac6a158e07406e1abae31679e8826349", size = 300737 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/99/f6/af0d1f58f86002be0cf1e2665cdd6f7a4a71cdc8a7a9438cdc9e3b5375fe/fastapi-0.115.4-py3-none-any.whl", hash = "sha256:0b504a063ffb3cf96a5e27dc1bc32c80ca743a2528574f9cdc77daa2d31b4742", size = 94732 }, +] + [[package]] name = "filelock" version = "3.16.1" @@ -255,6 +470,65 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b9/f8/feced7779d755758a52d1f6635d990b8d98dc0a29fa568bbe0625f18fdf3/filelock-3.16.1-py3-none-any.whl", hash = "sha256:2082e5703d51fbf98ea75855d9d5527e33d8ff23099bec374a134febee6946b0", size = 16163 }, ] +[[package]] +name = "frozenlist" +version = "1.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8f/ed/0f4cec13a93c02c47ec32d81d11c0c1efbadf4a471e3f3ce7cad366cbbd3/frozenlist-1.5.0.tar.gz", hash = "sha256:81d5af29e61b9c8348e876d442253723928dce6433e0e76cd925cd83f1b4b817", size = 39930 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/17/62/594a6829ac5679c25755362a9dc93486a8a45241394564309641425d3ff6/frozenlist-1.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e79225373c317ff1e35f210dd5f1344ff31066ba8067c307ab60254cd3a78ad5", size = 240946 }, + { url = "https://files.pythonhosted.org/packages/7e/75/6c8419d8f92c80dd0ee3f63bdde2702ce6398b0ac8410ff459f9b6f2f9cb/frozenlist-1.5.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9272fa73ca71266702c4c3e2d4a28553ea03418e591e377a03b8e3659d94fa76", size = 264608 }, + { url = "https://files.pythonhosted.org/packages/88/3e/82a6f0b84bc6fb7e0be240e52863c6d4ab6098cd62e4f5b972cd31e002e8/frozenlist-1.5.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:498524025a5b8ba81695761d78c8dd7382ac0b052f34e66939c42df860b8ff17", size = 261361 }, + { url = "https://files.pythonhosted.org/packages/fd/85/14e5f9ccac1b64ff2f10c927b3ffdf88772aea875882406f9ba0cec8ad84/frozenlist-1.5.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:92b5278ed9d50fe610185ecd23c55d8b307d75ca18e94c0e7de328089ac5dcba", size = 231649 }, + { url = "https://files.pythonhosted.org/packages/ee/59/928322800306f6529d1852323014ee9008551e9bb027cc38d276cbc0b0e7/frozenlist-1.5.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f3c8c1dacd037df16e85227bac13cca58c30da836c6f936ba1df0c05d046d8d", size = 241853 }, + { url = "https://files.pythonhosted.org/packages/7d/bd/e01fa4f146a6f6c18c5d34cab8abdc4013774a26c4ff851128cd1bd3008e/frozenlist-1.5.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f2ac49a9bedb996086057b75bf93538240538c6d9b38e57c82d51f75a73409d2", size = 243652 }, + { url = "https://files.pythonhosted.org/packages/a5/bd/e4771fd18a8ec6757033f0fa903e447aecc3fbba54e3630397b61596acf0/frozenlist-1.5.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e66cc454f97053b79c2ab09c17fbe3c825ea6b4de20baf1be28919460dd7877f", size = 241734 }, + { url = "https://files.pythonhosted.org/packages/21/13/c83821fa5544af4f60c5d3a65d054af3213c26b14d3f5f48e43e5fb48556/frozenlist-1.5.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:5a3ba5f9a0dfed20337d3e966dc359784c9f96503674c2faf015f7fe8e96798c", size = 260959 }, + { url = "https://files.pythonhosted.org/packages/71/f3/1f91c9a9bf7ed0e8edcf52698d23f3c211d8d00291a53c9f115ceb977ab1/frozenlist-1.5.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:6321899477db90bdeb9299ac3627a6a53c7399c8cd58d25da094007402b039ab", size = 262706 }, + { url = "https://files.pythonhosted.org/packages/4c/22/4a256fdf5d9bcb3ae32622c796ee5ff9451b3a13a68cfe3f68e2c95588ce/frozenlist-1.5.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:76e4753701248476e6286f2ef492af900ea67d9706a0155335a40ea21bf3b2f5", size = 250401 }, + { url = "https://files.pythonhosted.org/packages/98/a8/d0ac0b9276e1404f58fec3ab6e90a4f76b778a49373ccaf6a563f100dfbc/frozenlist-1.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0996c66760924da6e88922756d99b47512a71cfd45215f3570bf1e0b694c206a", size = 276357 }, + { url = "https://files.pythonhosted.org/packages/ad/c9/c7761084fa822f07dac38ac29f841d4587570dd211e2262544aa0b791d21/frozenlist-1.5.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a2fe128eb4edeabe11896cb6af88fca5346059f6c8d807e3b910069f39157869", size = 287516 }, + { url = "https://files.pythonhosted.org/packages/a1/ff/cd7479e703c39df7bdab431798cef89dc75010d8aa0ca2514c5b9321db27/frozenlist-1.5.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1a8ea951bbb6cacd492e3948b8da8c502a3f814f5d20935aae74b5df2b19cf3d", size = 283131 }, + { url = "https://files.pythonhosted.org/packages/59/a0/370941beb47d237eca4fbf27e4e91389fd68699e6f4b0ebcc95da463835b/frozenlist-1.5.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:de537c11e4aa01d37db0d403b57bd6f0546e71a82347a97c6a9f0dcc532b3a45", size = 261320 }, + { url = "https://files.pythonhosted.org/packages/b8/5f/c10123e8d64867bc9b4f2f510a32042a306ff5fcd7e2e09e5ae5100ee333/frozenlist-1.5.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c2623347b933fcb9095841f1cc5d4ff0b278addd743e0e966cb3d460278840d", size = 274877 }, + { url = "https://files.pythonhosted.org/packages/fa/79/38c505601ae29d4348f21706c5d89755ceded02a745016ba2f58bd5f1ea6/frozenlist-1.5.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:cee6798eaf8b1416ef6909b06f7dc04b60755206bddc599f52232606e18179d3", size = 269592 }, + { url = "https://files.pythonhosted.org/packages/19/e2/39f3a53191b8204ba9f0bb574b926b73dd2efba2a2b9d2d730517e8f7622/frozenlist-1.5.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:f5f9da7f5dbc00a604fe74aa02ae7c98bcede8a3b8b9666f9f86fc13993bc71a", size = 265934 }, + { url = "https://files.pythonhosted.org/packages/d5/c9/3075eb7f7f3a91f1a6b00284af4de0a65a9ae47084930916f5528144c9dd/frozenlist-1.5.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:90646abbc7a5d5c7c19461d2e3eeb76eb0b204919e6ece342feb6032c9325ae9", size = 283859 }, + { url = "https://files.pythonhosted.org/packages/05/f5/549f44d314c29408b962fa2b0e69a1a67c59379fb143b92a0a065ffd1f0f/frozenlist-1.5.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:bdac3c7d9b705d253b2ce370fde941836a5f8b3c5c2b8fd70940a3ea3af7f4f2", size = 287560 }, + { url = "https://files.pythonhosted.org/packages/9d/f8/cb09b3c24a3eac02c4c07a9558e11e9e244fb02bf62c85ac2106d1eb0c0b/frozenlist-1.5.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:03d33c2ddbc1816237a67f66336616416e2bbb6beb306e5f890f2eb22b959cdf", size = 277150 }, + { url = "https://files.pythonhosted.org/packages/e3/12/2aad87deb08a4e7ccfb33600871bbe8f0e08cb6d8224371387f3303654d7/frozenlist-1.5.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:feeb64bc9bcc6b45c6311c9e9b99406660a9c05ca8a5b30d14a78555088b0b3a", size = 282647 }, + { url = "https://files.pythonhosted.org/packages/77/f2/07f06b05d8a427ea0060a9cef6e63405ea9e0d761846b95ef3fb3be57111/frozenlist-1.5.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:683173d371daad49cffb8309779e886e59c2f369430ad28fe715f66d08d4ab1a", size = 289052 }, + { url = "https://files.pythonhosted.org/packages/bd/9f/8bf45a2f1cd4aa401acd271b077989c9267ae8463e7c8b1eb0d3f561b65e/frozenlist-1.5.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7d57d8f702221405a9d9b40f9da8ac2e4a1a8b5285aac6100f3393675f0a85ee", size = 291719 }, + { url = "https://files.pythonhosted.org/packages/41/d1/1f20fd05a6c42d3868709b7604c9f15538a29e4f734c694c6bcfc3d3b935/frozenlist-1.5.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:30c72000fbcc35b129cb09956836c7d7abf78ab5416595e4857d1cae8d6251a6", size = 267433 }, + { url = "https://files.pythonhosted.org/packages/af/f2/64b73a9bb86f5a89fb55450e97cd5c1f84a862d4ff90d9fd1a73ab0f64a5/frozenlist-1.5.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:000a77d6034fbad9b6bb880f7ec073027908f1b40254b5d6f26210d2dab1240e", size = 283591 }, + { url = "https://files.pythonhosted.org/packages/29/e2/ffbb1fae55a791fd6c2938dd9ea779509c977435ba3940b9f2e8dc9d5316/frozenlist-1.5.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:5d7f5a50342475962eb18b740f3beecc685a15b52c91f7d975257e13e029eca9", size = 273249 }, + { url = "https://files.pythonhosted.org/packages/2e/6e/008136a30798bb63618a114b9321b5971172a5abddff44a100c7edc5ad4f/frozenlist-1.5.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:87f724d055eb4785d9be84e9ebf0f24e392ddfad00b3fe036e43f489fafc9039", size = 271075 }, + { url = "https://files.pythonhosted.org/packages/ae/f0/4e71e54a026b06724cec9b6c54f0b13a4e9e298cc8db0f82ec70e151f5ce/frozenlist-1.5.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:6e9080bb2fb195a046e5177f10d9d82b8a204c0736a97a153c2466127de87784", size = 285398 }, + { url = "https://files.pythonhosted.org/packages/4d/36/70ec246851478b1c0b59f11ef8ade9c482ff447c1363c2bd5fad45098b12/frozenlist-1.5.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:9b93d7aaa36c966fa42efcaf716e6b3900438632a626fb09c049f6a2f09fc631", size = 294445 }, + { url = "https://files.pythonhosted.org/packages/37/e0/47f87544055b3349b633a03c4d94b405956cf2437f4ab46d0928b74b7526/frozenlist-1.5.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:52ef692a4bc60a6dd57f507429636c2af8b6046db8b31b18dac02cbc8f507f7f", size = 280569 }, + { url = "https://files.pythonhosted.org/packages/1f/22/462a3dd093d11df623179d7754a3b3269de3b42de2808cddef50ee0f4f48/frozenlist-1.5.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6482a5851f5d72767fbd0e507e80737f9c8646ae7fd303def99bfe813f76cf7f", size = 265636 }, + { url = "https://files.pythonhosted.org/packages/80/cf/e075e407fc2ae7328155a1cd7e22f932773c8073c1fc78016607d19cc3e5/frozenlist-1.5.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:44c49271a937625619e862baacbd037a7ef86dd1ee215afc298a417ff3270608", size = 270214 }, + { url = "https://files.pythonhosted.org/packages/a1/58/0642d061d5de779f39c50cbb00df49682832923f3d2ebfb0fedf02d05f7f/frozenlist-1.5.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:12f78f98c2f1c2429d42e6a485f433722b0061d5c0b0139efa64f396efb5886b", size = 273905 }, + { url = "https://files.pythonhosted.org/packages/ab/66/3fe0f5f8f2add5b4ab7aa4e199f767fd3b55da26e3ca4ce2cc36698e50c4/frozenlist-1.5.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ce3aa154c452d2467487765e3adc730a8c153af77ad84096bc19ce19a2400840", size = 250542 }, + { url = "https://files.pythonhosted.org/packages/f6/b8/260791bde9198c87a465224e0e2bb62c4e716f5d198fc3a1dacc4895dbd1/frozenlist-1.5.0-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9b7dc0c4338e6b8b091e8faf0db3168a37101943e687f373dce00959583f7439", size = 267026 }, + { url = "https://files.pythonhosted.org/packages/2e/a4/3d24f88c527f08f8d44ade24eaee83b2627793fa62fa07cbb7ff7a2f7d42/frozenlist-1.5.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:45e0896250900b5aa25180f9aec243e84e92ac84bd4a74d9ad4138ef3f5c97de", size = 257690 }, + { url = "https://files.pythonhosted.org/packages/de/9a/d311d660420b2beeff3459b6626f2ab4fb236d07afbdac034a4371fe696e/frozenlist-1.5.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:561eb1c9579d495fddb6da8959fd2a1fca2c6d060d4113f5844b433fc02f2641", size = 253893 }, + { url = "https://files.pythonhosted.org/packages/c6/23/e491aadc25b56eabd0f18c53bb19f3cdc6de30b2129ee0bc39cd387cd560/frozenlist-1.5.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:df6e2f325bfee1f49f81aaac97d2aa757c7646534a06f8f577ce184afe2f0a9e", size = 267006 }, + { url = "https://files.pythonhosted.org/packages/08/c4/ab918ce636a35fb974d13d666dcbe03969592aeca6c3ab3835acff01f79c/frozenlist-1.5.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:140228863501b44b809fb39ec56b5d4071f4d0aa6d216c19cbb08b8c5a7eadb9", size = 276157 }, + { url = "https://files.pythonhosted.org/packages/c0/29/3b7a0bbbbe5a34833ba26f686aabfe982924adbdcafdc294a7a129c31688/frozenlist-1.5.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:7707a25d6a77f5d27ea7dc7d1fc608aa0a478193823f88511ef5e6b8a48f9d03", size = 264642 }, + { url = "https://files.pythonhosted.org/packages/08/04/e2fddc92135276e07addbc1cf413acffa0c2d848b3e54cacf684e146df49/frozenlist-1.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f253985bb515ecd89629db13cb58d702035ecd8cfbca7d7a7e29a0e6d39af5f", size = 241756 }, + { url = "https://files.pythonhosted.org/packages/c6/52/be5ff200815d8a341aee5b16b6b707355e0ca3652953852238eb92b120c2/frozenlist-1.5.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:04a5c6babd5e8fb7d3c871dc8b321166b80e41b637c31a995ed844a6139942b6", size = 267718 }, + { url = "https://files.pythonhosted.org/packages/88/be/4bd93a58be57a3722fc544c36debdf9dcc6758f761092e894d78f18b8f20/frozenlist-1.5.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a9fe0f1c29ba24ba6ff6abf688cb0b7cf1efab6b6aa6adc55441773c252f7411", size = 263494 }, + { url = "https://files.pythonhosted.org/packages/32/ba/58348b90193caa096ce9e9befea6ae67f38dabfd3aacb47e46137a6250a8/frozenlist-1.5.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:226d72559fa19babe2ccd920273e767c96a49b9d3d38badd7c91a0fdeda8ea08", size = 232838 }, + { url = "https://files.pythonhosted.org/packages/f6/33/9f152105227630246135188901373c4f322cc026565ca6215b063f4c82f4/frozenlist-1.5.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:15b731db116ab3aedec558573c1a5eec78822b32292fe4f2f0345b7f697745c2", size = 242912 }, + { url = "https://files.pythonhosted.org/packages/a0/10/3db38fb3ccbafadd80a1b0d6800c987b0e3fe3ef2d117c6ced0246eea17a/frozenlist-1.5.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:366d8f93e3edfe5a918c874702f78faac300209a4d5bf38352b2c1bdc07a766d", size = 244763 }, + { url = "https://files.pythonhosted.org/packages/e2/cd/1df468fdce2f66a4608dffe44c40cdc35eeaa67ef7fd1d813f99a9a37842/frozenlist-1.5.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:1b96af8c582b94d381a1c1f51ffaedeb77c821c690ea5f01da3d70a487dd0a9b", size = 242841 }, + { url = "https://files.pythonhosted.org/packages/ee/5f/16097a5ca0bb6b6779c02cc9379c72fe98d56115d4c54d059fb233168fb6/frozenlist-1.5.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:c03eff4a41bd4e38415cbed054bbaff4a075b093e2394b6915dca34a40d1e38b", size = 263407 }, + { url = "https://files.pythonhosted.org/packages/0f/f7/58cd220ee1c2248ee65a32f5b4b93689e3fe1764d85537eee9fc392543bc/frozenlist-1.5.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:50cf5e7ee9b98f22bdecbabf3800ae78ddcc26e4a435515fc72d97903e8488e0", size = 265083 }, + { url = "https://files.pythonhosted.org/packages/62/b8/49768980caabf81ac4a2d156008f7cbd0107e6b36d08a313bb31035d9201/frozenlist-1.5.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:1e76bfbc72353269c44e0bc2cfe171900fbf7f722ad74c9a7b638052afe6a00c", size = 251564 }, + { url = "https://files.pythonhosted.org/packages/c6/c8/a5be5b7550c10858fcf9b0ea054baccab474da77d37f1e828ce043a3a5d4/frozenlist-1.5.0-py3-none-any.whl", hash = "sha256:d994863bba198a4a518b467bb971c56e1db3f180a25c6cf7bb1949c267f748c3", size = 11901 }, +] + [[package]] name = "fsspec" version = "2024.9.0" @@ -264,6 +538,69 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1d/a0/6aaea0c2fbea2f89bfd5db25fb1e3481896a423002ebe4e55288907a97a3/fsspec-2024.9.0-py3-none-any.whl", hash = "sha256:a0947d552d8a6efa72cc2c730b12c41d043509156966cca4fb157b0f2a0c574b", size = 179253 }, ] +[package.optional-dependencies] +http = [ + { name = "aiohttp", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, +] + +[[package]] +name = "h11" +version = "0.14.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f5/38/3af3d3633a34a3316095b39c8e8fb4853a28a536e55d347bd8d8e9a14b03/h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d", size = 100418 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/95/04/ff642e65ad6b90db43e668d70ffb6736436c7ce41fcc549f4e9472234127/h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761", size = 58259 }, +] + +[[package]] +name = "h5py" +version = "3.12.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/cc/0c/5c2b0a88158682aeafb10c1c2b735df5bc31f165bfe192f2ee9f2a23b5f1/h5py-3.12.1.tar.gz", hash = "sha256:326d70b53d31baa61f00b8aa5f95c2fcb9621a3ee8365d770c551a13dbbcbfdf", size = 411457 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/d0/4bf67c3937a2437c20844165766ddd1a1817ae6b9544c3743050d8e0f403/h5py-3.12.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b15d8dbd912c97541312c0e07438864d27dbca857c5ad634de68110c6beb1c2", size = 5168596 }, + { url = "https://files.pythonhosted.org/packages/85/bc/e76f4b2096e0859225f5441d1b7f5e2041fffa19fc2c16756c67078417aa/h5py-3.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59685fe40d8c1fbbee088c88cd4da415a2f8bee5c270337dc5a1c4aa634e3307", size = 5341537 }, + { url = "https://files.pythonhosted.org/packages/b0/62/e2b1f9723ff713e3bd3c16dfeceec7017eadc21ef063d8b7080c0fcdc58a/h5py-3.12.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1473348139b885393125126258ae2d70753ef7e9cec8e7848434f385ae72069e", size = 5273038 }, + { url = "https://files.pythonhosted.org/packages/e1/89/118c3255d6ff2db33b062ec996a762d99ae50c21f54a8a6047ae8eda1b9f/h5py-3.12.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:018a4597f35092ae3fb28ee851fdc756d2b88c96336b8480e124ce1ac6fb9166", size = 5452688 }, + { url = "https://files.pythonhosted.org/packages/af/52/c604adc06280c15a29037d4aa79a24fe54d8d0b51085e81ed24b2fa995f7/h5py-3.12.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:050a4f2c9126054515169c49cb900949814987f0c7ae74c341b0c9f9b5056834", size = 5194606 }, + { url = "https://files.pythonhosted.org/packages/fa/63/eeaacff417b393491beebabb8a3dc5342950409eb6d7b39d437289abdbae/h5py-3.12.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c4b41d1019322a5afc5082864dfd6359f8935ecd37c11ac0029be78c5d112c9", size = 5413256 }, + { url = "https://files.pythonhosted.org/packages/8a/4f/b74332f313bfbe94ba03fff784219b9db385e6139708e55b11490149f90a/h5py-3.12.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d3e465aee0ec353949f0f46bf6c6f9790a2006af896cee7c178a8c3e5090aa32", size = 5154390 }, + { url = "https://files.pythonhosted.org/packages/1a/57/93ea9e10a6457ea8d3b867207deb29a527e966a08a84c57ffd954e32152a/h5py-3.12.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba51c0c5e029bb5420a343586ff79d56e7455d496d18a30309616fdbeed1068f", size = 5378244 }, + { url = "https://files.pythonhosted.org/packages/7b/f9/e597b5fef05f161c67a18e8c61bf105209fd242f2612b0ad1aff7ecb0b9c/h5py-3.12.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6fdf6d7936fa824acfa27305fe2d9f39968e539d831c5bae0e0d83ed521ad1ac", size = 5190324 }, + { url = "https://files.pythonhosted.org/packages/8e/1d/631c200e6d5d067035c58028f305cf7f29c494ddfb9b9484a907a367c8bd/h5py-3.12.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:84342bffd1f82d4f036433e7039e241a243531a1d3acd7341b35ae58cdab05bf", size = 5361780 }, +] + +[[package]] +name = "httpcore" +version = "1.0.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "h11", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6a/41/d7d0a89eb493922c37d343b607bc1b5da7f5be7e383740b4753ad8943e90/httpcore-1.0.7.tar.gz", hash = "sha256:8551cb62a169ec7162ac7be8d4817d561f60e08eaa485234898414bb5a8a0b4c", size = 85196 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/87/f5/72347bc88306acb359581ac4d52f23c0ef445b57157adedb9aee0cd689d2/httpcore-1.0.7-py3-none-any.whl", hash = "sha256:a3fff8f43dc260d5bd363d9f9cf1830fa3a458b332856f34282de498ed420edd", size = 78551 }, +] + +[[package]] +name = "httpx" +version = "0.28.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "certifi", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "httpcore", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "idna", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517 }, +] + [[package]] name = "huggingface-hub" version = "0.25.0" @@ -282,6 +619,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d5/ce/1f8e61cd63175cc2e79233b954b1c4e85363c788fb3a1fa23c87a25c9b81/huggingface_hub-0.25.0-py3-none-any.whl", hash = "sha256:e2f357b35d72d5012cfd127108c4e14abcd61ba4ebc90a5a374dc2456cb34e12", size = 436429 }, ] +[[package]] +name = "humanfriendly" +version = "10.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/cc/3f/2c29224acb2e2df4d2046e4c73ee2662023c58ff5b113c4c1adac0886c43/humanfriendly-10.0.tar.gz", hash = "sha256:6b0b831ce8f15f7300721aa49829fc4e83921a9a301cc7f606be6686a2288ddc", size = 360702 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f0/0f/310fb31e39e2d734ccaa2c0fb981ee41f7bd5056ce9bc29b2248bd569169/humanfriendly-10.0-py2.py3-none-any.whl", hash = "sha256:1697e1a8a8f550fd43c2865cd84542fc175a61dcb779b6fee18cf6b6ccba1477", size = 86794 }, +] + [[package]] name = "identify" version = "2.6.1" @@ -389,6 +735,55 @@ wheels = [ { url = "https://download.pytorch.org/whl/test/Jinja2-3.1.4-py3-none-any.whl" }, ] +[[package]] +name = "jiter" +version = "0.8.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f8/70/90bc7bd3932e651486861df5c8ffea4ca7c77d28e8532ddefe2abc561a53/jiter-0.8.2.tar.gz", hash = "sha256:cd73d3e740666d0e639f678adb176fad25c1bcbdae88d8d7b857e1783bb4212d", size = 163007 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0d/17/57acab00507e60bd954eaec0837d9d7b119b4117ff49b8a62f2b646f32ed/jiter-0.8.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5c826a221851a8dc028eb6d7d6429ba03184fa3c7e83ae01cd6d3bd1d4bd17d", size = 335465 }, + { url = "https://files.pythonhosted.org/packages/74/b9/1a3ddd2bc95ae17c815b021521020f40c60b32137730126bada962ef32b4/jiter-0.8.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d35c864c2dff13dfd79fb070fc4fc6235d7b9b359efe340e1261deb21b9fcb66", size = 355570 }, + { url = "https://files.pythonhosted.org/packages/78/69/6d29e2296a934199a7d0dde673ecccf98c9c8db44caf0248b3f2b65483cb/jiter-0.8.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f557c55bc2b7676e74d39d19bcb8775ca295c7a028246175d6a8b431e70835e5", size = 381383 }, + { url = "https://files.pythonhosted.org/packages/22/d7/fbc4c3fb1bf65f9be22a32759b539f88e897aeb13fe84ab0266e4423487a/jiter-0.8.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:580ccf358539153db147e40751a0b41688a5ceb275e6f3e93d91c9467f42b2e3", size = 390454 }, + { url = "https://files.pythonhosted.org/packages/4d/a0/3993cda2e267fe679b45d0bcc2cef0b4504b0aa810659cdae9737d6bace9/jiter-0.8.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af102d3372e917cffce49b521e4c32c497515119dc7bd8a75665e90a718bbf08", size = 345039 }, + { url = "https://files.pythonhosted.org/packages/b9/ef/69c18562b4c09ce88fab5df1dcaf643f6b1a8b970b65216e7221169b81c4/jiter-0.8.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:cadcc978f82397d515bb2683fc0d50103acff2a180552654bb92d6045dec2c49", size = 376200 }, + { url = "https://files.pythonhosted.org/packages/4d/17/0b5a8de46a6ab4d836f70934036278b49b8530c292b29dde3483326d4555/jiter-0.8.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:ba5bdf56969cad2019d4e8ffd3f879b5fdc792624129741d3d83fc832fef8c7d", size = 511158 }, + { url = "https://files.pythonhosted.org/packages/6c/b2/c401a0a2554b36c9e6d6e4876b43790d75139cf3936f0222e675cbc23451/jiter-0.8.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:3b94a33a241bee9e34b8481cdcaa3d5c2116f575e0226e421bed3f7a6ea71cff", size = 503956 }, + { url = "https://files.pythonhosted.org/packages/e5/69/64058e18263d9a5f1e10f90c436853616d5f047d997c37c7b2df11b085ec/jiter-0.8.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a9584de0cd306072635fe4b89742bf26feae858a0683b399ad0c2509011b9dc0", size = 335506 }, + { url = "https://files.pythonhosted.org/packages/9d/14/b747f9a77b8c0542141d77ca1e2a7523e854754af2c339ac89a8b66527d6/jiter-0.8.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5a90a923338531b7970abb063cfc087eebae6ef8ec8139762007188f6bc69a9f", size = 355849 }, + { url = "https://files.pythonhosted.org/packages/53/e2/98a08161db7cc9d0e39bc385415890928ff09709034982f48eccfca40733/jiter-0.8.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d21974d246ed0181558087cd9f76e84e8321091ebfb3a93d4c341479a736f099", size = 381700 }, + { url = "https://files.pythonhosted.org/packages/7a/38/1674672954d35bce3b1c9af99d5849f9256ac8f5b672e020ac7821581206/jiter-0.8.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:32475a42b2ea7b344069dc1e81445cfc00b9d0e3ca837f0523072432332e9f74", size = 389710 }, + { url = "https://files.pythonhosted.org/packages/f8/9b/92f9da9a9e107d019bcf883cd9125fa1690079f323f5a9d5c6986eeec3c0/jiter-0.8.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b9931fd36ee513c26b5bf08c940b0ac875de175341cbdd4fa3be109f0492586", size = 345553 }, + { url = "https://files.pythonhosted.org/packages/44/a6/6d030003394e9659cd0d7136bbeabd82e869849ceccddc34d40abbbbb269/jiter-0.8.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ce0820f4a3a59ddced7fce696d86a096d5cc48d32a4183483a17671a61edfddc", size = 376388 }, + { url = "https://files.pythonhosted.org/packages/ad/8d/87b09e648e4aca5f9af89e3ab3cfb93db2d1e633b2f2931ede8dabd9b19a/jiter-0.8.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:8ffc86ae5e3e6a93765d49d1ab47b6075a9c978a2b3b80f0f32628f39caa0c88", size = 511226 }, + { url = "https://files.pythonhosted.org/packages/77/95/8008ebe4cdc82eac1c97864a8042ca7e383ed67e0ec17bfd03797045c727/jiter-0.8.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5127dc1abd809431172bc3fbe8168d6b90556a30bb10acd5ded41c3cfd6f43b6", size = 504134 }, + { url = "https://files.pythonhosted.org/packages/06/99/a2bf660d8ccffee9ad7ed46b4f860d2108a148d0ea36043fd16f4dc37e94/jiter-0.8.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:180a8aea058f7535d1c84183c0362c710f4750bef66630c05f40c93c2b152a0f", size = 334242 }, + { url = "https://files.pythonhosted.org/packages/a7/5f/cea1c17864828731f11427b9d1ab7f24764dbd9aaf4648a7f851164d2718/jiter-0.8.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:025337859077b41548bdcbabe38698bcd93cfe10b06ff66617a48ff92c9aec60", size = 356654 }, + { url = "https://files.pythonhosted.org/packages/e9/13/62774b7e5e7f5d5043efe1d0f94ead66e6d0f894ae010adb56b3f788de71/jiter-0.8.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ecff0dc14f409599bbcafa7e470c00b80f17abc14d1405d38ab02e4b42e55b57", size = 379967 }, + { url = "https://files.pythonhosted.org/packages/ec/fb/096b34c553bb0bd3f2289d5013dcad6074948b8d55212aa13a10d44c5326/jiter-0.8.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ffd9fee7d0775ebaba131f7ca2e2d83839a62ad65e8e02fe2bd8fc975cedeb9e", size = 389252 }, + { url = "https://files.pythonhosted.org/packages/17/61/beea645c0bf398ced8b199e377b61eb999d8e46e053bb285c91c3d3eaab0/jiter-0.8.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:14601dcac4889e0a1c75ccf6a0e4baf70dbc75041e51bcf8d0e9274519df6887", size = 345490 }, + { url = "https://files.pythonhosted.org/packages/d5/df/834aa17ad5dcc3cf0118821da0a0cf1589ea7db9832589278553640366bc/jiter-0.8.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:92249669925bc1c54fcd2ec73f70f2c1d6a817928480ee1c65af5f6b81cdf12d", size = 376991 }, + { url = "https://files.pythonhosted.org/packages/67/80/87d140399d382fb4ea5b3d56e7ecaa4efdca17cd7411ff904c1517855314/jiter-0.8.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:e725edd0929fa79f8349ab4ec7f81c714df51dc4e991539a578e5018fa4a7152", size = 510822 }, + { url = "https://files.pythonhosted.org/packages/5c/37/3394bb47bac1ad2cb0465601f86828a0518d07828a650722e55268cdb7e6/jiter-0.8.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:bf55846c7b7a680eebaf9c3c48d630e1bf51bdf76c68a5f654b8524335b0ad29", size = 503730 }, + { url = "https://files.pythonhosted.org/packages/7f/68/805978f2f446fa6362ba0cc2e4489b945695940656edd844e110a61c98f8/jiter-0.8.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58dc9bc9767a1101f4e5e22db1b652161a225874d66f0e5cb8e2c7d1c438b587", size = 333918 }, + { url = "https://files.pythonhosted.org/packages/b3/99/0f71f7be667c33403fa9706e5b50583ae5106d96fab997fa7e2f38ee8347/jiter-0.8.2-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:37b2998606d6dadbb5ccda959a33d6a5e853252d921fec1792fc902351bb4e2c", size = 356057 }, + { url = "https://files.pythonhosted.org/packages/8d/50/a82796e421a22b699ee4d2ce527e5bcb29471a2351cbdc931819d941a167/jiter-0.8.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4ab9a87f3784eb0e098f84a32670cfe4a79cb6512fd8f42ae3d0709f06405d18", size = 379790 }, + { url = "https://files.pythonhosted.org/packages/3c/31/10fb012b00f6d83342ca9e2c9618869ab449f1aa78c8f1b2193a6b49647c/jiter-0.8.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:79aec8172b9e3c6d05fd4b219d5de1ac616bd8da934107325a6c0d0e866a21b6", size = 388285 }, + { url = "https://files.pythonhosted.org/packages/c8/81/f15ebf7de57be488aa22944bf4274962aca8092e4f7817f92ffa50d3ee46/jiter-0.8.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:711e408732d4e9a0208008e5892c2966b485c783cd2d9a681f3eb147cf36c7ef", size = 344764 }, + { url = "https://files.pythonhosted.org/packages/b3/e8/0cae550d72b48829ba653eb348cdc25f3f06f8a62363723702ec18e7be9c/jiter-0.8.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:653cf462db4e8c41995e33d865965e79641ef45369d8a11f54cd30888b7e6ff1", size = 376620 }, + { url = "https://files.pythonhosted.org/packages/b8/50/e5478ff9d82534a944c03b63bc217c5f37019d4a34d288db0f079b13c10b/jiter-0.8.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:9c63eaef32b7bebac8ebebf4dabebdbc6769a09c127294db6babee38e9f405b9", size = 510402 }, + { url = "https://files.pythonhosted.org/packages/8e/1e/3de48bbebbc8f7025bd454cedc8c62378c0e32dd483dece5f4a814a5cb55/jiter-0.8.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:eb21aaa9a200d0a80dacc7a81038d2e476ffe473ffdd9c91eb745d623561de05", size = 503018 }, + { url = "https://files.pythonhosted.org/packages/a0/4c/c02408042e6a7605ec063daed138e07b982fdb98467deaaf1c90950cf2c6/jiter-0.8.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b2dd880785088ff2ad21ffee205e58a8c1ddabc63612444ae41e5e4b321b39c0", size = 342875 }, + { url = "https://files.pythonhosted.org/packages/56/4c/b413977c20bbb359b4d6c91d04f7f36fc525af0b7778119815477fc97242/jiter-0.8.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f7200b8f7619d36aa51c803fd52020a2dfbea36ffec1b5e22cab11fd34d95a6d", size = 335344 }, + { url = "https://files.pythonhosted.org/packages/b0/59/51b080519938192edd33b4e8d48adb7e9bf9e0d699ec8b91119b9269fc75/jiter-0.8.2-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:70bf4c43652cc294040dbb62256c83c8718370c8b93dd93d934b9a7bf6c4f53c", size = 356298 }, + { url = "https://files.pythonhosted.org/packages/72/bb/828db5ea406916d7b2232be31393f782b0f71bcb0b128750c4a028157565/jiter-0.8.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f9d471356dc16f84ed48768b8ee79f29514295c7295cb41e1133ec0b2b8d637d", size = 381703 }, + { url = "https://files.pythonhosted.org/packages/c0/88/45d33a8728733e161e9783c54d8ecca0fc4c1aa74b1cebea1d97917eddc3/jiter-0.8.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:859e8eb3507894093d01929e12e267f83b1d5f6221099d3ec976f0c995cb6bd9", size = 391281 }, + { url = "https://files.pythonhosted.org/packages/45/3e/142712e0f45c28ad8a678dc8732a78294ce5a36fc694141f772bb827a8f2/jiter-0.8.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eaa58399c01db555346647a907b4ef6d4f584b123943be6ed5588c3f2359c9f4", size = 345553 }, + { url = "https://files.pythonhosted.org/packages/36/42/9b463b59fd22687b6da1afcad6c9adc870464a808208651de73f1dbeda09/jiter-0.8.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:8f2d5ed877f089862f4c7aacf3a542627c1496f972a34d0474ce85ee7d939c27", size = 377063 }, + { url = "https://files.pythonhosted.org/packages/83/b3/44b1f5cd2e4eb15757eec341b25399da4c90515bb881ef6636b50a8c08a5/jiter-0.8.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:03c9df035d4f8d647f8c210ddc2ae0728387275340668fb30d2421e17d9a0841", size = 512543 }, + { url = "https://files.pythonhosted.org/packages/46/4e/c695c803aa2b668c057b2dea1cdd7a884d1a819ce610cec0be9666210bfd/jiter-0.8.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8bd2a824d08d8977bb2794ea2682f898ad3d8837932e3a74937e93d62ecbb637", size = 505141 }, +] + [[package]] name = "jupyterlab-widgets" version = "3.0.13" @@ -398,6 +793,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a9/93/858e87edc634d628e5d752ba944c2833133a28fa87bb093e6832ced36a3e/jupyterlab_widgets-3.0.13-py3-none-any.whl", hash = "sha256:e3cda2c233ce144192f1e29914ad522b2f4c40e77214b0cc97377ca3d323db54", size = 214392 }, ] +[[package]] +name = "lark" +version = "1.2.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/af/60/bc7622aefb2aee1c0b4ba23c1446d3e30225c8770b38d7aedbfb65ca9d5a/lark-1.2.2.tar.gz", hash = "sha256:ca807d0162cd16cef15a8feecb862d7319e7a09bdb13aef927968e45040fed80", size = 252132 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2d/00/d90b10b962b4277f5e64a78b6609968859ff86889f5b898c1a778c06ec00/lark-1.2.2-py3-none-any.whl", hash = "sha256:c2276486b02f0f1b90be155f2c8ba4a8e194d42775786db622faccd652d8e80c", size = 111036 }, +] + [[package]] name = "markdown-it-py" version = "3.0.0" @@ -451,6 +855,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979 }, ] +[[package]] +name = "mpi4py" +version = "4.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/08/34/8499a92a387d24d0092c38089f8195f13c5c76f0f814126af3fe363e5636/mpi4py-4.0.1.tar.gz", hash = "sha256:f3174b245775d556f4fddb32519a2066ef0592edc810c5b5a59238f9a0a40c89", size = 466179 } + [[package]] name = "mpmath" version = "1.3.0" @@ -459,6 +869,60 @@ wheels = [ { url = "https://download.pytorch.org/whl/test/mpmath-1.3.0-py3-none-any.whl" }, ] +[[package]] +name = "multidict" +version = "6.1.0" +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } +dependencies = [ + { name = "typing-extensions", marker = "(python_full_version < '3.11' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform == 'windows')" }, +] +wheels = [ + { url = "https://download.pytorch.org/whl/nightly/multidict-6.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/multidict-6.1.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" }, + { url = "https://download.pytorch.org/whl/nightly/multidict-6.1.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/multidict-6.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/multidict-6.1.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl" }, + { url = "https://download.pytorch.org/whl/nightly/multidict-6.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/multidict-6.1.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" }, + { url = "https://download.pytorch.org/whl/nightly/multidict-6.1.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/multidict-6.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/multidict-6.1.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl" }, + { url = "https://download.pytorch.org/whl/nightly/multidict-6.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/multidict-6.1.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" }, + { url = "https://download.pytorch.org/whl/nightly/multidict-6.1.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/multidict-6.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/multidict-6.1.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl" }, + { url = "https://download.pytorch.org/whl/nightly/multidict-6.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/multidict-6.1.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" }, + { url = "https://download.pytorch.org/whl/nightly/multidict-6.1.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/multidict-6.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/multidict-6.1.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl" }, + { url = "https://download.pytorch.org/whl/nightly/multidict-6.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/multidict-6.1.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" }, + { url = "https://download.pytorch.org/whl/nightly/multidict-6.1.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/multidict-6.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/multidict-6.1.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl" }, + { url = "https://download.pytorch.org/whl/nightly/multidict-6.1.0-py3-none-any.whl" }, +] + +[[package]] +name = "multiprocess" +version = "0.70.15" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "dill", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/68/e0/a77ca96e772e13c828fa52f3ad370d413bef194aeaf78b7c6611870ad815/multiprocess-0.70.15.tar.gz", hash = "sha256:f20eed3036c0ef477b07a4177cf7c1ba520d9a2677870a4f47fe026f0cd6787e", size = 1894495 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/83/97/08402e5ec72c1cc461df1f2e654c1e55527bb05c22120dcdf59745189df6/multiprocess-0.70.15-pp39-pypy39_pp73-manylinux_2_24_i686.whl", hash = "sha256:0eac53214d664c49a34695e5824872db4006b1a465edd7459a251809c3773370", size = 133501 }, + { url = "https://files.pythonhosted.org/packages/71/47/5d12db2427465486c0b336cc67753e8826b756a2e4d4ef6385f4f01b355b/multiprocess-0.70.15-pp39-pypy39_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:1a51dd34096db47fb21fa2b839e615b051d51b97af9a67afbcdaa67186b44883", size = 133501 }, + { url = "https://files.pythonhosted.org/packages/35/a8/36d8d7b3e46b377800d8dec47891cdf05842d1a2366909ae4a0c89fbc5e6/multiprocess-0.70.15-py310-none-any.whl", hash = "sha256:7dd58e33235e83cf09d625e55cffd7b0f0eede7ee9223cdd666a87624f60c21a", size = 134824 }, + { url = "https://files.pythonhosted.org/packages/e7/41/96ac938770ba6e7d5ae1d8c9cafebac54b413549042c6260f0d0a6ec6622/multiprocess-0.70.15-py311-none-any.whl", hash = "sha256:134f89053d82c9ed3b73edd3a2531eb791e602d4f4156fc92a79259590bd9670", size = 135392 }, + { url = "https://files.pythonhosted.org/packages/ca/3f/8354ce12fd13bd5c5bb4722261a10ca1d6e2eb7c1c08fa3d8a4e9dc98f44/multiprocess-0.70.15-py37-none-any.whl", hash = "sha256:f7d4a1629bccb433114c3b4885f69eccc200994323c80f6feee73b0edc9199c5", size = 116276 }, + { url = "https://files.pythonhosted.org/packages/c2/a6/c5cb599d917904878f220a4dbdfdcc4ef291dd3956c35b3b0dc6fc42fb6d/multiprocess-0.70.15-py38-none-any.whl", hash = "sha256:bee9afba476c91f9ebee7beeee0601face9eff67d822e893f9a893725fbd6316", size = 132626 }, + { url = "https://files.pythonhosted.org/packages/c6/c9/820b5ab056f4ada76fbe05bd481a948f287957d6cbfd59e2dd2618b408c1/multiprocess-0.70.15-py39-none-any.whl", hash = "sha256:3e0953f5d52b4c76f1c973eaf8214554d146f2be5decb48e928e55c7a2d19338", size = 133349 }, +] + [[package]] name = "mypy" version = "1.11.2" @@ -647,9 +1111,18 @@ wheels = [ { url = "https://download.pytorch.org/whl/test/cu126/nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_x86_64.whl" }, ] +[[package]] +name = "nvidia-ml-py" +version = "12.560.30" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/53/10/5f34de4a71db8b2b7ec4269f4a33287f24c23e2857ea3187c977b7bc3604/nvidia-ml-py-12.560.30.tar.gz", hash = "sha256:f0254dc7400647680a072ee02509bfd46102b60bdfeca321576d4d4817e7fe97", size = 39194 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/f3/a69ce0b1a1e12fbf6b2ad9f4c14c9999fdbdf15f2478d210f0fd501ddc98/nvidia_ml_py-12.560.30-py3-none-any.whl", hash = "sha256:fea371c94d63e38a611c17bbb85fe400e9c8ddb9e8684a9cd0e47786a4bc3c73", size = 40526 }, +] + [[package]] name = "nvidia-modelopt" -version = "0.17.0" +version = "0.21.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cloudpickle", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, @@ -663,14 +1136,14 @@ dependencies = [ { name = "tqdm", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/32/b2/2a688cc56d875a08e3e732af642d0ae0f4a7253dc1a00fd271e2fe1a79e9/nvidia_modelopt-0.17.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f646da43a46ddf10eb2c2ddebe49cb1a58e631348808c5640afe217f8ab223ae", size = 4699528 }, - { url = "https://files.pythonhosted.org/packages/b2/b3/98bea42c27fb9ca9f6c502ec3624f2000732f9cc642652b75378f6fed1db/nvidia_modelopt-0.17.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:67cd2faad4c0084864330533112cb231fdf404526086f0803e593f65b7868f47", size = 4502680 }, - { url = "https://files.pythonhosted.org/packages/26/7e/beb7461b6bedf7fff043ded616a520468c992df9d005d3bfd87eab1dab71/nvidia_modelopt-0.17.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ee7480fdb7e1e6d22e17092bfbc6abf6140f070853c21662e4b6162aec228feb", size = 5108601 }, - { url = "https://files.pythonhosted.org/packages/43/79/45572053e3e928f32d62fcae8704a2667a646356ac752a15490829398987/nvidia_modelopt-0.17.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:16e6a8df04e6551a9f2059cb5c68608d18b38cb14dec609ac920035ceeccbf98", size = 4966637 }, - { url = "https://files.pythonhosted.org/packages/7e/a9/53a82e52fd0d5c44a6a09f48db30b7e133e393ad1b7cef280ac41a4a8bca/nvidia_modelopt-0.17.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:557b92dbeeb4d9dedb58d9a6251901870a7458f6e773838efa9155693c5d21e1", size = 5145613 }, - { url = "https://files.pythonhosted.org/packages/4e/59/d3d2c7d59c5b56c86c0c1c55896d6800539f5b1aae7c3ea43331c117fe95/nvidia_modelopt-0.17.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:3a07969bb8c684b3e9f9eecc018bc465d0f683d9e9a4792b5eb435b9ad85f4a5", size = 4949031 }, - { url = "https://files.pythonhosted.org/packages/61/00/39c8ad3969003f7ba87a54197626aa822047bd98a0e09428045e9dd00335/nvidia_modelopt-0.17.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0da51c4464af883eb722b20679633aabc31fadff6b4bee3c33eb4eda3d920484", size = 4694964 }, - { url = "https://files.pythonhosted.org/packages/41/6b/7eb1d60a4706ee1346990bb7b45bf4244be60e24c52894fd817f5ffb649a/nvidia_modelopt-0.17.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:fdc86a37698cad0d780e91ac07166ef4e3b073e87200116351f0b2c7dcada2bf", size = 4502582 }, + { url = "https://files.pythonhosted.org/packages/ef/69/a49a427ae0839a3cdc5eb1868caf563d74d014d9ec4aa8475e3cf18cb7c6/nvidia_modelopt-0.21.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8fef6a2ea3bf51a62c31226825d7d35f883e8d260f4d57a7fcb3b6209431351e", size = 2502694 }, + { url = "https://files.pythonhosted.org/packages/cd/7d/1cd0b8001924e48d54f26c3a92d4afc9340af44e4282090a3f166ebf4943/nvidia_modelopt-0.21.1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:0d26c3193694f90e8a940133d559d3ce44045a4544b131dd7f29f718110413ca", size = 2590928 }, + { url = "https://files.pythonhosted.org/packages/64/85/2531c5ad2881625fda1f5c4e7def6f53b58b7266f2137e94157cab5f2ea5/nvidia_modelopt-0.21.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b4410bab11e7c12ac5fdbb1a11a8d485b65e85867496ad91e849e449e27a726", size = 2710695 }, + { url = "https://files.pythonhosted.org/packages/a0/95/774371981a2f564dc4855c18d21f861084ceb686bac751ef6a0d1b03993c/nvidia_modelopt-0.21.1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:1bab36c96598615f81efef4be148a23460c4a321289b5b789e33c6f179fdd932", size = 2816194 }, + { url = "https://files.pythonhosted.org/packages/fc/df/55326825e6c870a7d9af5417f734d0d0c1bb1277bb6b352889a4514cff5a/nvidia_modelopt-0.21.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f29fb9d39e6ba51e0f546a93ad9b275d888bd22f74c37547783a4708b9309821", size = 2715861 }, + { url = "https://files.pythonhosted.org/packages/3b/4a/955bc90607e944a8d08a5f77fe026a1102412195ae5e921bdfb4dda14dd7/nvidia_modelopt-0.21.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:d2b8656057ed0d7505d4895542ecdaa6df7acbb6cebb6ff06182450f5ac5d570", size = 2813404 }, + { url = "https://files.pythonhosted.org/packages/39/ed/b5855c053c96990f7d38113c9694df4c42130710abb0eba7d0a3bef793c9/nvidia_modelopt-0.21.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9dfcd53a4b46b4e6ad2744aa389b1dd68560d748f1645b4facf563e3b5b79a05", size = 2498183 }, + { url = "https://files.pythonhosted.org/packages/f5/fd/2b99bbb2d3bb714ac94d35fa32795f2c70814e3cc3aed1612cea41ca10de/nvidia_modelopt-0.21.1-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:6ba34f712a0c4333e3a37b4454d2f6e24d4dd45a99b5d7640a516f3d12ded8c8", size = 2590413 }, ] [package.optional-dependencies] @@ -678,6 +1151,7 @@ hf = [ { name = "accelerate", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, { name = "diffusers", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, { name = "huggingface-hub", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "peft", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, { name = "transformers", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, ] torch = [ @@ -714,6 +1188,85 @@ wheels = [ { url = "https://download.pytorch.org/whl/test/cu126/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl" }, ] +[[package]] +name = "onnx" +version = "1.17.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "protobuf", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9a/54/0e385c26bf230d223810a9c7d06628d954008a5e5e4b73ee26ef02327282/onnx-1.17.0.tar.gz", hash = "sha256:48ca1a91ff73c1d5e3ea2eef20ae5d0e709bb8a2355ed798ffc2169753013fd3", size = 12165120 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/75/0d/831807a18db2a5e8f7813848c59272b904a4ef3939fe4d1288cbce9ea735/onnx-1.17.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d545335cb49d4d8c47cc803d3a805deb7ad5d9094dc67657d66e568610a36d7d", size = 15908420 }, + { url = "https://files.pythonhosted.org/packages/dd/5b/c4f95dbe652d14aeba9afaceb177e9ffc48ac3c03048dd3f872f26f07e34/onnx-1.17.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3193a3672fc60f1a18c0f4c93ac81b761bc72fd8a6c2035fa79ff5969f07713e", size = 16046244 }, + { url = "https://files.pythonhosted.org/packages/7b/e3/cc80110e5996ca61878f7b4c73c7a286cd88918ff35eacb60dc75ab11ef5/onnx-1.17.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f01a4b63d4e1d8ec3e2f069e7b798b2955810aa434f7361f01bc8ca08d69cce4", size = 15908949 }, + { url = "https://files.pythonhosted.org/packages/b1/2f/91092557ed478e323a2b4471e2081fdf88d1dd52ae988ceaf7db4e4506ff/onnx-1.17.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4a183c6178be001bf398260e5ac2c927dc43e7746e8638d6c05c20e321f8c949", size = 16048190 }, + { url = "https://files.pythonhosted.org/packages/f0/6c/f040652277f514ecd81b7251841f96caa5538365af7df07f86c6018cda2b/onnx-1.17.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d955ba2939878a520a97614bcf2e79c1df71b29203e8ced478fa78c9a9c63c2", size = 15907522 }, + { url = "https://files.pythonhosted.org/packages/3d/7c/67f4952d1b56b3f74a154b97d0dd0630d525923b354db117d04823b8b49b/onnx-1.17.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f3fb5cc4e2898ac5312a7dc03a65133dd2abf9a5e520e69afb880a7251ec97a", size = 16046307 }, + { url = "https://files.pythonhosted.org/packages/61/94/d753c230d56234dd01ad939590a2ed33221b57c61abe513ff6823a69af6e/onnx-1.17.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e19fd064b297f7773b4c1150f9ce6213e6d7d041d7a9201c0d348041009cdcd", size = 15908316 }, + { url = "https://files.pythonhosted.org/packages/3d/da/c19d0f20d310045f4701d75ecba4f765153251d48a32f27a5d6b0a7e3799/onnx-1.17.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8167295f576055158a966161f8ef327cb491c06ede96cc23392be6022071b6ed", size = 16046488 }, +] + +[[package]] +name = "onnx-graphsurgeon" +version = "0.5.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "onnx", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/0d/20/93e7143af3a0b3b3d9f3306bfc46e55d0d307242b4c1bf36ff108460e5a3/onnx_graphsurgeon-0.5.2-py2.py3-none-any.whl", hash = "sha256:10c130d6129fdeee02945f8103b5b112e6fd4d9b356e2dd3e80f53e0ebee7b5c", size = 56430 }, +] + +[[package]] +name = "openai" +version = "1.54.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "distro", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "httpx", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "jiter", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "pydantic", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "sniffio", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "tqdm", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "typing-extensions", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4b/d8/ff02ef5255bf1b0c75f6badced772b4fd5c29a517d2a0be6f4e0a65a3227/openai-1.54.3.tar.gz", hash = "sha256:7511b74eeb894ac0b0253dc71f087a15d2e4d71d22d0088767205143d880cca6", size = 313963 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/77/85/e7adeee84edd24c6cc119b2ccaaacd9579c6a2c7f72d05e936ea6b33594e/openai-1.54.3-py3-none-any.whl", hash = "sha256:f18dbaf09c50d70c4185b892a2a553f80681d1d866323a2da7f7be2f688615d5", size = 389619 }, +] + +[[package]] +name = "optimum" +version = "1.23.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "coloredlogs", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "datasets", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "huggingface-hub", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "numpy", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "packaging", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "sympy", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "torch", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "transformers", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f5/95/44eb569e2a70f9c63dd75f80fea8495eec464c29b988188ebcae940a6470/optimum-1.23.3.tar.gz", hash = "sha256:2089bd73d1232686473a80effd53800f8a8c385c02126e80d35c07227c1b9bf5", size = 341546 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/48/33/97cf226c47e4cf5a79159668732038cdd6c0199c72782d5b5a0db54f9a2d/optimum-1.23.3-py3-none-any.whl", hash = "sha256:ac34b497310e74e919e8eb3bc01cfea48bca304ade3e3ce8a7707d125120001a", size = 424070 }, +] + +[[package]] +name = "ordered-set" +version = "4.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4c/ca/bfac8bc689799bcca4157e0e0ced07e70ce125193fc2e166d2e685b7e2fe/ordered-set-4.1.0.tar.gz", hash = "sha256:694a8e44c87657c59292ede72891eb91d34131f6531463aab3009191c77364a8", size = 12826 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/33/55/af02708f230eb77084a299d7b08175cff006dea4f2721074b92cdb0296c0/ordered_set-4.1.0-py3-none-any.whl", hash = "sha256:046e1132c71fcf3330438a539928932caf51ddbc582496833e23de611de14562", size = 7634 }, +] + [[package]] name = "packaging" version = "24.1" @@ -723,6 +1276,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/08/aa/cc0199a5f0ad350994d660967a8efb233fe0416e4639146c089643407ce6/packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124", size = 53985 }, ] +[[package]] +name = "pandas" +version = "2.2.3" +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } +dependencies = [ + { name = "numpy", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "python-dateutil", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "pytz", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "tzdata", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, +] +wheels = [ + { url = "https://download.pytorch.org/whl/nightly/pandas-2.2.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pandas-2.2.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pandas-2.2.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pandas-2.2.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pandas-2.2.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pandas-2.2.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pandas-2.2.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pandas-2.2.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pandas-2.2.3-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pandas-2.2.3-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pandas-2.2.3-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pandas-2.2.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, +] + [[package]] name = "parameterized" version = "0.9.0" @@ -750,6 +1328,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08", size = 31191 }, ] +[[package]] +name = "peft" +version = "0.14.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "accelerate", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "huggingface-hub", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "numpy", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "packaging", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "psutil", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "pyyaml", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "safetensors", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "torch", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "tqdm", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "transformers", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/21/33/fb0c31eaa8162c01e9250b21aa65d46a5339f17a818a97c68391db2ff44b/peft-0.14.0.tar.gz", hash = "sha256:546d69af7b42f5ef715a3d3261ed818bc917ae6055e5d7e187ed3f2c76ad72dc", size = 411902 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/05/e58e3aaa36544d30a917814e336fc65a746f708e5874945e92999bc22fa3/peft-0.14.0-py3-none-any.whl", hash = "sha256:2f04f3a870c3baf30f15e7dcaa5dd70d3e54cfdd146d3c6c187735d3ae0a0700", size = 374831 }, +] + [[package]] name = "pexpect" version = "4.9.0" @@ -826,6 +1425,14 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556 }, ] +[[package]] +name = "polygraphy" +version = "0.49.9" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4a/f5/a2b20c677c1a856cc9e08cd0b5a5105450ed5253e369e938ddd31d91c547/polygraphy-0.49.9-py2.py3-none-any.whl", hash = "sha256:62ae22825efdd3288222e5b1d2d791fe58e87844fcd848bcd1251fbce02ba956", size = 346910 }, +] + [[package]] name = "pre-commit" version = "3.8.0" @@ -854,6 +1461,81 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e8/23/22750c4b768f09386d1c3cc4337953e8936f48a888fa6dddfb669b2c9088/prompt_toolkit-3.0.47-py3-none-any.whl", hash = "sha256:0d7bfa67001d5e39d02c224b663abc33687405033a8c422d0d675a5a13361d10", size = 386411 }, ] +[[package]] +name = "propcache" +version = "0.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/20/c8/2a13f78d82211490855b2fb303b6721348d0787fdd9a12ac46d99d3acde1/propcache-0.2.1.tar.gz", hash = "sha256:3f77ce728b19cb537714499928fe800c3dda29e8d9428778fc7c186da4c09a64", size = 41735 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4d/3d/31c9c29ee7192defc05aa4d01624fd85a41cf98e5922aaed206017329944/propcache-0.2.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f9479aa06a793c5aeba49ce5c5692ffb51fcd9a7016e017d555d5e2b0045d212", size = 204809 }, + { url = "https://files.pythonhosted.org/packages/10/a1/e4050776f4797fc86140ac9a480d5dc069fbfa9d499fe5c5d2fa1ae71f07/propcache-0.2.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d9631c5e8b5b3a0fda99cb0d29c18133bca1e18aea9effe55adb3da1adef80d3", size = 219109 }, + { url = "https://files.pythonhosted.org/packages/c9/c0/e7ae0df76343d5e107d81e59acc085cea5fd36a48aa53ef09add7503e888/propcache-0.2.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3156628250f46a0895f1f36e1d4fbe062a1af8718ec3ebeb746f1d23f0c5dc4d", size = 217368 }, + { url = "https://files.pythonhosted.org/packages/fc/e1/e0a2ed6394b5772508868a977d3238f4afb2eebaf9976f0b44a8d347ad63/propcache-0.2.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b6fb63ae352e13748289f04f37868099e69dba4c2b3e271c46061e82c745634", size = 205124 }, + { url = "https://files.pythonhosted.org/packages/50/c1/e388c232d15ca10f233c778bbdc1034ba53ede14c207a72008de45b2db2e/propcache-0.2.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:887d9b0a65404929641a9fabb6452b07fe4572b269d901d622d8a34a4e9043b2", size = 195463 }, + { url = "https://files.pythonhosted.org/packages/0a/fd/71b349b9def426cc73813dbd0f33e266de77305e337c8c12bfb0a2a82bfb/propcache-0.2.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a96dc1fa45bd8c407a0af03b2d5218392729e1822b0c32e62c5bf7eeb5fb3958", size = 198358 }, + { url = "https://files.pythonhosted.org/packages/02/f2/d7c497cd148ebfc5b0ae32808e6c1af5922215fe38c7a06e4e722fe937c8/propcache-0.2.1-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:a7e65eb5c003a303b94aa2c3852ef130230ec79e349632d030e9571b87c4698c", size = 195560 }, + { url = "https://files.pythonhosted.org/packages/bb/57/f37041bbe5e0dfed80a3f6be2612a3a75b9cfe2652abf2c99bef3455bbad/propcache-0.2.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:999779addc413181912e984b942fbcc951be1f5b3663cd80b2687758f434c583", size = 196895 }, + { url = "https://files.pythonhosted.org/packages/83/36/ae3cc3e4f310bff2f064e3d2ed5558935cc7778d6f827dce74dcfa125304/propcache-0.2.1-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:19a0f89a7bb9d8048d9c4370c9c543c396e894c76be5525f5e1ad287f1750ddf", size = 207124 }, + { url = "https://files.pythonhosted.org/packages/8c/c4/811b9f311f10ce9d31a32ff14ce58500458443627e4df4ae9c264defba7f/propcache-0.2.1-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:1ac2f5fe02fa75f56e1ad473f1175e11f475606ec9bd0be2e78e4734ad575034", size = 210442 }, + { url = "https://files.pythonhosted.org/packages/18/dd/a1670d483a61ecac0d7fc4305d91caaac7a8fc1b200ea3965a01cf03bced/propcache-0.2.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:574faa3b79e8ebac7cb1d7930f51184ba1ccf69adfdec53a12f319a06030a68b", size = 203219 }, + { url = "https://files.pythonhosted.org/packages/03/7a/793aa12f0537b2e520bf09f4c6833706b63170a211ad042ca71cbf79d9cb/propcache-0.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b480c6a4e1138e1aa137c0079b9b6305ec6dcc1098a8ca5196283e8a49df95a9", size = 232136 }, + { url = "https://files.pythonhosted.org/packages/f1/38/b921b3168d72111769f648314100558c2ea1d52eb3d1ba7ea5c4aa6f9848/propcache-0.2.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d27b84d5880f6d8aa9ae3edb253c59d9f6642ffbb2c889b78b60361eed449787", size = 239706 }, + { url = "https://files.pythonhosted.org/packages/14/29/4636f500c69b5edea7786db3c34eb6166f3384b905665ce312a6e42c720c/propcache-0.2.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:857112b22acd417c40fa4595db2fe28ab900c8c5fe4670c7989b1c0230955465", size = 238531 }, + { url = "https://files.pythonhosted.org/packages/85/14/01fe53580a8e1734ebb704a3482b7829a0ef4ea68d356141cf0994d9659b/propcache-0.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cf6c4150f8c0e32d241436526f3c3f9cbd34429492abddbada2ffcff506c51af", size = 231063 }, + { url = "https://files.pythonhosted.org/packages/33/5c/1d961299f3c3b8438301ccfbff0143b69afcc30c05fa28673cface692305/propcache-0.2.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:66d4cfda1d8ed687daa4bc0274fcfd5267873db9a5bc0418c2da19273040eeb7", size = 220134 }, + { url = "https://files.pythonhosted.org/packages/00/d0/ed735e76db279ba67a7d3b45ba4c654e7b02bc2f8050671ec365d8665e21/propcache-0.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c2f992c07c0fca81655066705beae35fc95a2fa7366467366db627d9f2ee097f", size = 220009 }, + { url = "https://files.pythonhosted.org/packages/75/90/ee8fab7304ad6533872fee982cfff5a53b63d095d78140827d93de22e2d4/propcache-0.2.1-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:4a571d97dbe66ef38e472703067021b1467025ec85707d57e78711c085984e54", size = 212199 }, + { url = "https://files.pythonhosted.org/packages/eb/ec/977ffaf1664f82e90737275873461695d4c9407d52abc2f3c3e24716da13/propcache-0.2.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:bb6178c241278d5fe853b3de743087be7f5f4c6f7d6d22a3b524d323eecec505", size = 214827 }, + { url = "https://files.pythonhosted.org/packages/57/48/031fb87ab6081764054821a71b71942161619549396224cbb242922525e8/propcache-0.2.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:ad1af54a62ffe39cf34db1aa6ed1a1873bd548f6401db39d8e7cd060b9211f82", size = 228009 }, + { url = "https://files.pythonhosted.org/packages/1a/06/ef1390f2524850838f2390421b23a8b298f6ce3396a7cc6d39dedd4047b0/propcache-0.2.1-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:e7048abd75fe40712005bcfc06bb44b9dfcd8e101dda2ecf2f5aa46115ad07ca", size = 231638 }, + { url = "https://files.pythonhosted.org/packages/38/2a/101e6386d5a93358395da1d41642b79c1ee0f3b12e31727932b069282b1d/propcache-0.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:160291c60081f23ee43d44b08a7e5fb76681221a8e10b3139618c5a9a291b84e", size = 222788 }, + { url = "https://files.pythonhosted.org/packages/7f/14/7ae06a6cf2a2f1cb382586d5a99efe66b0b3d0c6f9ac2f759e6f7af9d7cf/propcache-0.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:049324ee97bb67285b49632132db351b41e77833678432be52bdd0289c0e05e4", size = 241869 }, + { url = "https://files.pythonhosted.org/packages/cc/59/227a78be960b54a41124e639e2c39e8807ac0c751c735a900e21315f8c2b/propcache-0.2.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1cd9a1d071158de1cc1c71a26014dcdfa7dd3d5f4f88c298c7f90ad6f27bb46d", size = 247884 }, + { url = "https://files.pythonhosted.org/packages/84/58/f62b4ffaedf88dc1b17f04d57d8536601e4e030feb26617228ef930c3279/propcache-0.2.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:98110aa363f1bb4c073e8dcfaefd3a5cea0f0834c2aab23dda657e4dab2f53b5", size = 248486 }, + { url = "https://files.pythonhosted.org/packages/1c/07/ebe102777a830bca91bbb93e3479cd34c2ca5d0361b83be9dbd93104865e/propcache-0.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:647894f5ae99c4cf6bb82a1bb3a796f6e06af3caa3d32e26d2350d0e3e3faf24", size = 243649 }, + { url = "https://files.pythonhosted.org/packages/ed/bc/4f7aba7f08f520376c4bb6a20b9a981a581b7f2e385fa0ec9f789bb2d362/propcache-0.2.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bfd3223c15bebe26518d58ccf9a39b93948d3dcb3e57a20480dfdd315356baff", size = 229103 }, + { url = "https://files.pythonhosted.org/packages/fe/d5/04ac9cd4e51a57a96f78795e03c5a0ddb8f23ec098b86f92de028d7f2a6b/propcache-0.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d71264a80f3fcf512eb4f18f59423fe82d6e346ee97b90625f283df56aee103f", size = 226607 }, + { url = "https://files.pythonhosted.org/packages/e3/f0/24060d959ea41d7a7cc7fdbf68b31852331aabda914a0c63bdb0e22e96d6/propcache-0.2.1-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:e73091191e4280403bde6c9a52a6999d69cdfde498f1fdf629105247599b57ec", size = 221153 }, + { url = "https://files.pythonhosted.org/packages/77/a7/3ac76045a077b3e4de4859a0753010765e45749bdf53bd02bc4d372da1a0/propcache-0.2.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:3935bfa5fede35fb202c4b569bb9c042f337ca4ff7bd540a0aa5e37131659348", size = 222151 }, + { url = "https://files.pythonhosted.org/packages/e7/af/5e29da6f80cebab3f5a4dcd2a3240e7f56f2c4abf51cbfcc99be34e17f0b/propcache-0.2.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:f508b0491767bb1f2b87fdfacaba5f7eddc2f867740ec69ece6d1946d29029a6", size = 233812 }, + { url = "https://files.pythonhosted.org/packages/8c/89/ebe3ad52642cc5509eaa453e9f4b94b374d81bae3265c59d5c2d98efa1b4/propcache-0.2.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:1672137af7c46662a1c2be1e8dc78cb6d224319aaa40271c9257d886be4363a6", size = 238829 }, + { url = "https://files.pythonhosted.org/packages/e9/2f/6b32f273fa02e978b7577159eae7471b3cfb88b48563b1c2578b2d7ca0bb/propcache-0.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b74c261802d3d2b85c9df2dfb2fa81b6f90deeef63c2db9f0e029a3cac50b518", size = 230704 }, + { url = "https://files.pythonhosted.org/packages/4e/3e/021b6cd86c0acc90d74784ccbb66808b0bd36067a1bf3e2deb0f3845f618/propcache-0.2.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ba278acf14471d36316159c94a802933d10b6a1e117b8554fe0d0d9b75c9d536", size = 224819 }, + { url = "https://files.pythonhosted.org/packages/3c/57/c2fdeed1b3b8918b1770a133ba5c43ad3d78e18285b0c06364861ef5cc38/propcache-0.2.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4e6281aedfca15301c41f74d7005e6e3f4ca143584ba696ac69df4f02f40d629", size = 229625 }, + { url = "https://files.pythonhosted.org/packages/9d/81/70d4ff57bf2877b5780b466471bebf5892f851a7e2ca0ae7ffd728220281/propcache-0.2.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5b750a8e5a1262434fb1517ddf64b5de58327f1adc3524a5e44c2ca43305eb0b", size = 232934 }, + { url = "https://files.pythonhosted.org/packages/3c/b9/bb51ea95d73b3fb4100cb95adbd4e1acaf2cbb1fd1083f5468eeb4a099a8/propcache-0.2.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf72af5e0fb40e9babf594308911436c8efde3cb5e75b6f206c34ad18be5c052", size = 227361 }, + { url = "https://files.pythonhosted.org/packages/f1/20/3c6d696cd6fd70b29445960cc803b1851a1131e7a2e4ee261ee48e002bcd/propcache-0.2.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b2d0a12018b04f4cb820781ec0dffb5f7c7c1d2a5cd22bff7fb055a2cb19ebce", size = 213904 }, + { url = "https://files.pythonhosted.org/packages/a1/cb/1593bfc5ac6d40c010fa823f128056d6bc25b667f5393781e37d62f12005/propcache-0.2.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:e800776a79a5aabdb17dcc2346a7d66d0777e942e4cd251defeb084762ecd17d", size = 212632 }, + { url = "https://files.pythonhosted.org/packages/6d/5c/e95617e222be14a34c709442a0ec179f3207f8a2b900273720501a70ec5e/propcache-0.2.1-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:4160d9283bd382fa6c0c2b5e017acc95bc183570cd70968b9202ad6d8fc48dce", size = 207897 }, + { url = "https://files.pythonhosted.org/packages/8e/3b/56c5ab3dc00f6375fbcdeefdede5adf9bee94f1fab04adc8db118f0f9e25/propcache-0.2.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:30b43e74f1359353341a7adb783c8f1b1c676367b011709f466f42fda2045e95", size = 208118 }, + { url = "https://files.pythonhosted.org/packages/86/25/d7ef738323fbc6ebcbce33eb2a19c5e07a89a3df2fded206065bd5e868a9/propcache-0.2.1-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:58791550b27d5488b1bb52bc96328456095d96206a250d28d874fafe11b3dfaf", size = 217851 }, + { url = "https://files.pythonhosted.org/packages/b3/77/763e6cef1852cf1ba740590364ec50309b89d1c818e3256d3929eb92fabf/propcache-0.2.1-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:0f022d381747f0dfe27e99d928e31bc51a18b65bb9e481ae0af1380a6725dd1f", size = 222630 }, + { url = "https://files.pythonhosted.org/packages/4f/e9/0f86be33602089c701696fbed8d8c4c07b6ee9605c5b7536fd27ed540c5b/propcache-0.2.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:297878dc9d0a334358f9b608b56d02e72899f3b8499fc6044133f0d319e2ec30", size = 216269 }, + { url = "https://files.pythonhosted.org/packages/22/59/6fe80a3fe7720f715f2c0f6df250dacbd7cad42832410dbd84c719c52f78/propcache-0.2.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5eee736daafa7af6d0a2dc15cc75e05c64f37fc37bafef2e00d77c14171c2097", size = 207792 }, + { url = "https://files.pythonhosted.org/packages/4a/68/584cd51dd8f4d0f5fff5b128ce0cdb257cde903898eecfb92156bbc2c780/propcache-0.2.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f7a31fc1e1bd362874863fdeed71aed92d348f5336fd84f2197ba40c59f061bd", size = 223280 }, + { url = "https://files.pythonhosted.org/packages/85/cb/4c3528460c41e61b06ec3f970c0f89f87fa21f63acac8642ed81a886c164/propcache-0.2.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cba4cfa1052819d16699e1d55d18c92b6e094d4517c41dd231a8b9f87b6fa681", size = 221293 }, + { url = "https://files.pythonhosted.org/packages/69/c0/560e050aa6d31eeece3490d1174da508f05ab27536dfc8474af88b97160a/propcache-0.2.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f089118d584e859c62b3da0892b88a83d611c2033ac410e929cb6754eec0ed16", size = 208259 }, + { url = "https://files.pythonhosted.org/packages/0c/87/d6c86a77632eb1ba86a328e3313159f246e7564cb5951e05ed77555826a0/propcache-0.2.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:781e65134efaf88feb447e8c97a51772aa75e48b794352f94cb7ea717dedda0d", size = 198632 }, + { url = "https://files.pythonhosted.org/packages/3a/2b/3690ea7b662dc762ab7af5f3ef0e2d7513c823d193d7b2a1b4cda472c2be/propcache-0.2.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:31f5af773530fd3c658b32b6bdc2d0838543de70eb9a2156c03e410f7b0d3aae", size = 203516 }, + { url = "https://files.pythonhosted.org/packages/4d/b5/afe716c16c23c77657185c257a41918b83e03993b6ccdfa748e5e7d328e9/propcache-0.2.1-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:a7a078f5d37bee6690959c813977da5291b24286e7b962e62a94cec31aa5188b", size = 199402 }, + { url = "https://files.pythonhosted.org/packages/a4/c0/2d2df3aa7f8660d0d4cc4f1e00490c48d5958da57082e70dea7af366f876/propcache-0.2.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:cea7daf9fc7ae6687cf1e2c049752f19f146fdc37c2cc376e7d0032cf4f25347", size = 200528 }, + { url = "https://files.pythonhosted.org/packages/21/c8/65ac9142f5e40c8497f7176e71d18826b09e06dd4eb401c9a4ee41aa9c74/propcache-0.2.1-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:8b3489ff1ed1e8315674d0775dc7d2195fb13ca17b3808721b54dbe9fd020faf", size = 211254 }, + { url = "https://files.pythonhosted.org/packages/09/e4/edb70b447a1d8142df51ec7511e84aa64d7f6ce0a0fdf5eb55363cdd0935/propcache-0.2.1-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:9403db39be1393618dd80c746cb22ccda168efce239c73af13c3763ef56ffc04", size = 214589 }, + { url = "https://files.pythonhosted.org/packages/cb/02/817f309ec8d8883287781d6d9390f80b14db6e6de08bc659dfe798a825c2/propcache-0.2.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:5d97151bc92d2b2578ff7ce779cdb9174337390a535953cbb9452fb65164c587", size = 207283 }, + { url = "https://files.pythonhosted.org/packages/41/b6/c5319caea262f4821995dca2107483b94a3345d4607ad797c76cb9c36bcc/propcache-0.2.1-py3-none-any.whl", hash = "sha256:52277518d6aae65536e9cea52d4e7fd2f7a66f4aa2d30ed3f2fcea620ace3c54", size = 11818 }, +] + +[[package]] +name = "protobuf" +version = "5.29.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a5/73/4e6295c1420a9d20c9c351db3a36109b4c9aa601916cb7c6871e3196a1ca/protobuf-5.29.2.tar.gz", hash = "sha256:b2cc8e8bb7c9326996f0e160137b0861f1a82162502658df2951209d0cb0309e", size = 424901 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e4/20/38fc33b60dcfb380507b99494aebe8c34b68b8ac7d32808c4cebda3f6f6b/protobuf-5.29.2-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:494229ecd8c9009dd71eda5fd57528395d1eacdf307dbece6c12ad0dd09e912e", size = 319562 }, + { url = "https://files.pythonhosted.org/packages/90/4d/c3d61e698e0e41d926dbff6aa4e57428ab1a6fc3b5e1deaa6c9ec0fd45cf/protobuf-5.29.2-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:b6b0d416bbbb9d4fbf9d0561dbfc4e324fd522f61f7af0fe0f282ab67b22477e", size = 319662 }, + { url = "https://files.pythonhosted.org/packages/f3/fd/c7924b4c2a1c61b8f4b64edd7a31ffacf63432135a2606f03a2f0d75a750/protobuf-5.29.2-py3-none-any.whl", hash = "sha256:fde4554c0e578a5a0bcc9a276339594848d1e89f9ea47b4427c80e5d72f90181", size = 172539 }, +] + [[package]] name = "psutil" version = "6.0.0" @@ -891,6 +1573,47 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8e/37/efad0257dc6e593a18957422533ff0f87ede7c9c6ea010a2177d738fb82f/pure_eval-0.2.3-py3-none-any.whl", hash = "sha256:1db8e35b67b3d218d818ae653e27f06c3aa420901fa7b081ca98cbedc874e0d0", size = 11842 }, ] +[[package]] +name = "pyarrow" +version = "18.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/7f/7b/640785a9062bb00314caa8a387abce547d2a420cf09bd6c715fe659ccffb/pyarrow-18.1.0.tar.gz", hash = "sha256:9386d3ca9c145b5539a1cfc75df07757dff870168c959b473a0bccbc3abc8c73", size = 1118671 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a3/2a/526545a7464b5fb2fa6e2c4bad16ca90e59e1843025c534fd907b7f73e5a/pyarrow-18.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f443122c8e31f4c9199cb23dca29ab9427cef990f283f80fe15b8e124bcc49b", size = 39213905 }, + { url = "https://files.pythonhosted.org/packages/8a/77/4b3fab91a30e19e233e738d0c5eca5a8f6dd05758bc349a2ca262c65de79/pyarrow-18.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0a03da7f2758645d17b7b4f83c8bffeae5bbb7f974523fe901f36288d2eab71", size = 40128881 }, + { url = "https://files.pythonhosted.org/packages/aa/e2/a88e16c5e45e562449c52305bd3bc2f9d704295322d3434656e7ccac1444/pyarrow-18.1.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:ba17845efe3aa358ec266cf9cc2800fa73038211fb27968bfa88acd09261a470", size = 38627517 }, + { url = "https://files.pythonhosted.org/packages/6d/84/8037c20005ccc7b869726465be0957bd9c29cfc88612962030f08292ad06/pyarrow-18.1.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:3c35813c11a059056a22a3bef520461310f2f7eea5c8a11ef9de7062a23f8d56", size = 40060187 }, + { url = "https://files.pythonhosted.org/packages/75/7e/332055ac913373e89256dce9d14b7708f55f7bd5be631456c897f0237738/pyarrow-18.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f39a2e0ed32a0970e4e46c262753417a60c43a3246972cfc2d3eb85aedd01b21", size = 39212135 }, + { url = "https://files.pythonhosted.org/packages/8c/64/5099cdb325828722ef7ffeba9a4696f238eb0cdeae227f831c2d77fcf1bd/pyarrow-18.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e31e9417ba9c42627574bdbfeada7217ad8a4cbbe45b9d6bdd4b62abbca4c6f6", size = 40125195 }, + { url = "https://files.pythonhosted.org/packages/83/88/1938d783727db1b178ff71bc6a6143d7939e406db83a9ec23cad3dad325c/pyarrow-18.1.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:01c034b576ce0eef554f7c3d8c341714954be9b3f5d5bc7117006b85fcf302fe", size = 38641884 }, + { url = "https://files.pythonhosted.org/packages/5e/b5/9e14e9f7590e0eaa435ecea84dabb137284a4dbba7b3c337b58b65b76d95/pyarrow-18.1.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:f266a2c0fc31995a06ebd30bcfdb7f615d7278035ec5b1cd71c48d56daaf30b0", size = 40076877 }, + { url = "https://files.pythonhosted.org/packages/68/f9/29fb659b390312a7345aeb858a9d9c157552a8852522f2c8bad437c29c0a/pyarrow-18.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:543ad8459bc438efc46d29a759e1079436290bd583141384c6f7a1068ed6f992", size = 39203624 }, + { url = "https://files.pythonhosted.org/packages/6e/f6/19360dae44200e35753c5c2889dc478154cd78e61b1f738514c9f131734d/pyarrow-18.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0743e503c55be0fdb5c08e7d44853da27f19dc854531c0570f9f394ec9671d54", size = 40139341 }, + { url = "https://files.pythonhosted.org/packages/bb/e6/9b3afbbcf10cc724312e824af94a2e993d8ace22994d823f5c35324cebf5/pyarrow-18.1.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:d4b3d2a34780645bed6414e22dda55a92e0fcd1b8a637fba86800ad737057e33", size = 38618629 }, + { url = "https://files.pythonhosted.org/packages/3a/2e/3b99f8a3d9e0ccae0e961978a0d0089b25fb46ebbcfb5ebae3cca179a5b3/pyarrow-18.1.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:c52f81aa6f6575058d8e2c782bf79d4f9fdc89887f16825ec3a66607a5dd8e30", size = 40078661 }, + { url = "https://files.pythonhosted.org/packages/41/d7/ed85001edfb96200ff606943cff71d64f91926ab42828676c0fc0db98963/pyarrow-18.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:acb7564204d3c40babf93a05624fc6a8ec1ab1def295c363afc40b0c9e66c191", size = 39194527 }, + { url = "https://files.pythonhosted.org/packages/59/16/35e28eab126342fa391593415d79477e89582de411bb95232f28b131a769/pyarrow-18.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:74de649d1d2ccb778f7c3afff6085bd5092aed4c23df9feeb45dd6b16f3811aa", size = 40131443 }, + { url = "https://files.pythonhosted.org/packages/0c/95/e855880614c8da20f4cd74fa85d7268c725cf0013dc754048593a38896a0/pyarrow-18.1.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:f96bd502cb11abb08efea6dab09c003305161cb6c9eafd432e35e76e7fa9b90c", size = 38608750 }, + { url = "https://files.pythonhosted.org/packages/54/9d/f253554b1457d4fdb3831b7bd5f8f00f1795585a606eabf6fec0a58a9c38/pyarrow-18.1.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:36ac22d7782554754a3b50201b607d553a8d71b78cdf03b33c1125be4b52397c", size = 40066690 }, + { url = "https://files.pythonhosted.org/packages/5e/e3/3b16c3190f3d71d3b10f6758d2d5f7779ef008c4fd367cedab3ed178a9f7/pyarrow-18.1.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aebc13a11ed3032d8dd6e7171eb6e86d40d67a5639d96c35142bd568b9299324", size = 39119106 }, + { url = "https://files.pythonhosted.org/packages/1d/d6/5d704b0d25c3c79532f8c0639f253ec2803b897100f64bcb3f53ced236e5/pyarrow-18.1.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d6cf5c05f3cee251d80e98726b5c7cc9f21bab9e9783673bac58e6dfab57ecc8", size = 40090940 }, + { url = "https://files.pythonhosted.org/packages/37/29/366bc7e588220d74ec00e497ac6710c2833c9176f0372fe0286929b2d64c/pyarrow-18.1.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:11b676cd410cf162d3f6a70b43fb9e1e40affbc542a1e9ed3681895f2962d3d9", size = 38548177 }, + { url = "https://files.pythonhosted.org/packages/c8/11/fabf6ecabb1fe5b7d96889228ca2a9158c4c3bb732e3b8ee3f7f6d40b703/pyarrow-18.1.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:b76130d835261b38f14fc41fdfb39ad8d672afb84c447126b84d5472244cfaba", size = 40043567 }, + { url = "https://files.pythonhosted.org/packages/84/c9/62ef9c6281c0e5b4ee1afa9d7bd556e72e06da6706b7906c32c15e69b3d6/pyarrow-18.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f97b31b4c4e21ff58c6f330235ff893cc81e23da081b1a4b1c982075e0ed4e9", size = 39226870 }, + { url = "https://files.pythonhosted.org/packages/b2/99/a6e89e71655a38475e76b060777c8bf69c078b772bec3b7daf7361440f05/pyarrow-18.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4a4813cb8ecf1809871fd2d64a8eff740a1bd3691bbe55f01a3cf6c5ec869754", size = 40139114 }, + { url = "https://files.pythonhosted.org/packages/64/a9/06d79923890682e4fe7a16524abee307407008a413115354aaf3226b8410/pyarrow-18.1.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:05a5636ec3eb5cc2a36c6edb534a38ef57b2ab127292a716d00eabb887835f1e", size = 38639231 }, + { url = "https://files.pythonhosted.org/packages/3b/8c/4c3ed19026a00740b81fe1c87f3ff235b2763a0a1ddf5711a9d026b775ce/pyarrow-18.1.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:73eeed32e724ea3568bb06161cad5fa7751e45bc2228e33dcb10c614044165c7", size = 40070949 }, +] + +[[package]] +name = "pycparser" +version = "2.22" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/1d/b2/31537cf4b1ca988837256c910a668b553fceb8f069bedc4b1c826024b52c/pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6", size = 172736 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/13/a3/a812df4e2dd5696d1f351d58b8fe16a405b234ad2886a0dab9183fb78109/pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc", size = 117552 }, +] + [[package]] name = "pydantic" version = "2.9.2" @@ -975,6 +1698,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f7/3f/01c8b82017c199075f8f788d0d906b9ffbbc5a47dc9918a945e13d5a2bda/pygments-2.18.0-py3-none-any.whl", hash = "sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a", size = 1205513 }, ] +[[package]] +name = "pynvml" +version = "12.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-ml-py", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/26/6f/6b5880ed0239e85b9a39aed103b65b2ef81425beef9f45e5c035bf008330/pynvml-12.0.0.tar.gz", hash = "sha256:299ce2451a6a17e6822d6faee750103e25b415f06f59abb8db65d30f794166f5", size = 33636 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ed/df/f7cf07a65a96dd11d71f346f9c2863accdd4784da83af7181b067d556cbc/pynvml-12.0.0-py3-none-any.whl", hash = "sha256:fdff84b62a27dbe98e08e1a647eb77342bef1aebe0878bcd15e99a83fcbecb9e", size = 26560 }, +] + +[[package]] +name = "pyproject-hooks" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e7/82/28175b2414effca1cdac8dc99f76d660e7a4fb0ceefa4b4ab8f5f6742925/pyproject_hooks-1.2.0.tar.gz", hash = "sha256:1e859bd5c40fae9448642dd871adf459e5e2084186e8d2c2a79a824c970da1f8", size = 19228 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bd/24/12818598c362d7f300f18e74db45963dbcb85150324092410c8b49405e42/pyproject_hooks-1.2.0-py3-none-any.whl", hash = "sha256:9e5c6bfa8dcc30091c74b0cf803c81fdd29d94f01992a7707bc97babb1141913", size = 10216 }, +] + [[package]] name = "pytest" version = "8.3.3" @@ -1026,6 +1770,65 @@ wheels = [ { url = "https://download.pytorch.org/whl/test/PyYAML-6.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, ] +[[package]] +name = "pyzmq" +version = "26.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cffi", marker = "(implementation_name == 'pypy' and sys_platform == 'linux') or (implementation_name == 'pypy' and sys_platform == 'windows')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fd/05/bed626b9f7bb2322cdbbf7b4bd8f54b1b617b0d2ab2d3547d6e39428a48e/pyzmq-26.2.0.tar.gz", hash = "sha256:070672c258581c8e4f640b5159297580a9974b026043bd4ab0470be9ed324f1f", size = 271975 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b6/09/b51b6683fde5ca04593a57bbe81788b6b43114d8f8ee4e80afc991e14760/pyzmq-26.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:89289a5ee32ef6c439086184529ae060c741334b8970a6855ec0b6ad3ff28764", size = 673199 }, + { url = "https://files.pythonhosted.org/packages/c9/78/486f3e2e824f3a645238332bf5a4c4b4477c3063033a27c1e4052358dee2/pyzmq-26.2.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5506f06d7dc6ecf1efacb4a013b1f05071bb24b76350832c96449f4a2d95091c", size = 911762 }, + { url = "https://files.pythonhosted.org/packages/5e/3b/2eb1667c9b866f53e76ee8b0c301b0469745a23bd5a87b7ee3d5dd9eb6e5/pyzmq-26.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ea039387c10202ce304af74def5021e9adc6297067f3441d348d2b633e8166a", size = 868773 }, + { url = "https://files.pythonhosted.org/packages/16/29/ca99b4598a9dc7e468b5417eda91f372b595be1e3eec9b7cbe8e5d3584e8/pyzmq-26.2.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:a2224fa4a4c2ee872886ed00a571f5e967c85e078e8e8c2530a2fb01b3309b88", size = 868834 }, + { url = "https://files.pythonhosted.org/packages/ad/e5/9efaeb1d2f4f8c50da04144f639b042bc52869d3a206d6bf672ab3522163/pyzmq-26.2.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:28ad5233e9c3b52d76196c696e362508959741e1a005fb8fa03b51aea156088f", size = 1202861 }, + { url = "https://files.pythonhosted.org/packages/c3/62/c721b5608a8ac0a69bb83cbb7d07a56f3ff00b3991a138e44198a16f94c7/pyzmq-26.2.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:1c17211bc037c7d88e85ed8b7d8f7e52db6dc8eca5590d162717c654550f7282", size = 1515304 }, + { url = "https://files.pythonhosted.org/packages/87/84/e8bd321aa99b72f48d4606fc5a0a920154125bd0a4608c67eab742dab087/pyzmq-26.2.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b8f86dd868d41bea9a5f873ee13bf5551c94cf6bc51baebc6f85075971fe6eea", size = 1414712 }, + { url = "https://files.pythonhosted.org/packages/e1/bf/c67fd638c2f9fbbab8090a3ee779370b97c82b84cc12d0c498b285d7b2c0/pyzmq-26.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77eb0968da535cba0470a5165468b2cac7772cfb569977cff92e240f57e31bef", size = 673129 }, + { url = "https://files.pythonhosted.org/packages/86/94/99085a3f492aa538161cbf27246e8886ff850e113e0c294a5b8245f13b52/pyzmq-26.2.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ace4f71f1900a548f48407fc9be59c6ba9d9aaf658c2eea6cf2779e72f9f317", size = 910107 }, + { url = "https://files.pythonhosted.org/packages/31/1d/346809e8a9b999646d03f21096428453465b1bca5cd5c64ecd048d9ecb01/pyzmq-26.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:92a78853d7280bffb93df0a4a6a2498cba10ee793cc8076ef797ef2f74d107cf", size = 867960 }, + { url = "https://files.pythonhosted.org/packages/ab/68/6fb6ae5551846ad5beca295b7bca32bf0a7ce19f135cb30e55fa2314e6b6/pyzmq-26.2.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:689c5d781014956a4a6de61d74ba97b23547e431e9e7d64f27d4922ba96e9d6e", size = 869204 }, + { url = "https://files.pythonhosted.org/packages/0f/f9/18417771dee223ccf0f48e29adf8b4e25ba6d0e8285e33bcbce078070bc3/pyzmq-26.2.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0aca98bc423eb7d153214b2df397c6421ba6373d3397b26c057af3c904452e37", size = 1203351 }, + { url = "https://files.pythonhosted.org/packages/e0/46/f13e67fe0d4f8a2315782cbad50493de6203ea0d744610faf4d5f5b16e90/pyzmq-26.2.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:1f3496d76b89d9429a656293744ceca4d2ac2a10ae59b84c1da9b5165f429ad3", size = 1514204 }, + { url = "https://files.pythonhosted.org/packages/50/11/ddcf7343b7b7a226e0fc7b68cbf5a5bb56291fac07f5c3023bb4c319ebb4/pyzmq-26.2.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5c2b3bfd4b9689919db068ac6c9911f3fcb231c39f7dd30e3138be94896d18e6", size = 1414339 }, + { url = "https://files.pythonhosted.org/packages/4f/ef/5a23ec689ff36d7625b38d121ef15abfc3631a9aecb417baf7a4245e4124/pyzmq-26.2.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:55cf66647e49d4621a7e20c8d13511ef1fe1efbbccf670811864452487007e08", size = 665923 }, + { url = "https://files.pythonhosted.org/packages/ae/61/d436461a47437d63c6302c90724cf0981883ec57ceb6073873f32172d676/pyzmq-26.2.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4661c88db4a9e0f958c8abc2b97472e23061f0bc737f6f6179d7a27024e1faa5", size = 903400 }, + { url = "https://files.pythonhosted.org/packages/47/42/fc6d35ecefe1739a819afaf6f8e686f7f02a4dd241c78972d316f403474c/pyzmq-26.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ea7f69de383cb47522c9c208aec6dd17697db7875a4674c4af3f8cfdac0bdeae", size = 860034 }, + { url = "https://files.pythonhosted.org/packages/07/3b/44ea6266a6761e9eefaa37d98fabefa112328808ac41aa87b4bbb668af30/pyzmq-26.2.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:7f98f6dfa8b8ccaf39163ce872bddacca38f6a67289116c8937a02e30bbe9711", size = 860579 }, + { url = "https://files.pythonhosted.org/packages/38/6f/4df2014ab553a6052b0e551b37da55166991510f9e1002c89cab7ce3b3f2/pyzmq-26.2.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:e3e0210287329272539eea617830a6a28161fbbd8a3271bf4150ae3e58c5d0e6", size = 1196246 }, + { url = "https://files.pythonhosted.org/packages/38/9d/ee240fc0c9fe9817f0c9127a43238a3e28048795483c403cc10720ddef22/pyzmq-26.2.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:6b274e0762c33c7471f1a7471d1a2085b1a35eba5cdc48d2ae319f28b6fc4de3", size = 1507441 }, + { url = "https://files.pythonhosted.org/packages/85/4f/01711edaa58d535eac4a26c294c617c9a01f09857c0ce191fd574d06f359/pyzmq-26.2.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:29c6a4635eef69d68a00321e12a7d2559fe2dfccfa8efae3ffb8e91cd0b36a8b", size = 1406498 }, + { url = "https://files.pythonhosted.org/packages/68/ba/f4280c58ff71f321602a6e24fd19879b7e79793fb8ab14027027c0fb58ef/pyzmq-26.2.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4d80b1dd99c1942f74ed608ddb38b181b87476c6a966a88a950c7dee118fdf50", size = 665485 }, + { url = "https://files.pythonhosted.org/packages/77/b5/c987a5c53c7d8704216f29fc3d810b32f156bcea488a940e330e1bcbb88d/pyzmq-26.2.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8c997098cc65e3208eca09303630e84d42718620e83b733d0fd69543a9cab9cb", size = 903484 }, + { url = "https://files.pythonhosted.org/packages/29/c9/07da157d2db18c72a7eccef8e684cefc155b712a88e3d479d930aa9eceba/pyzmq-26.2.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ad1bc8d1b7a18497dda9600b12dc193c577beb391beae5cd2349184db40f187", size = 859981 }, + { url = "https://files.pythonhosted.org/packages/43/09/e12501bd0b8394b7d02c41efd35c537a1988da67fc9c745cae9c6c776d31/pyzmq-26.2.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:bea2acdd8ea4275e1278350ced63da0b166421928276c7c8e3f9729d7402a57b", size = 860334 }, + { url = "https://files.pythonhosted.org/packages/eb/ff/f5ec1d455f8f7385cc0a8b2acd8c807d7fade875c14c44b85c1bddabae21/pyzmq-26.2.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:23f4aad749d13698f3f7b64aad34f5fc02d6f20f05999eebc96b89b01262fb18", size = 1196179 }, + { url = "https://files.pythonhosted.org/packages/ec/8a/bb2ac43295b1950fe436a81fc5b298be0b96ac76fb029b514d3ed58f7b27/pyzmq-26.2.0-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:a4f96f0d88accc3dbe4a9025f785ba830f968e21e3e2c6321ccdfc9aef755115", size = 1507668 }, + { url = "https://files.pythonhosted.org/packages/a9/49/dbc284ebcfd2dca23f6349227ff1616a7ee2c4a35fe0a5d6c3deff2b4fed/pyzmq-26.2.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:ced65e5a985398827cc9276b93ef6dfabe0273c23de8c7931339d7e141c2818e", size = 1406539 }, + { url = "https://files.pythonhosted.org/packages/82/86/3fe917870e15ee1c3ad48229a2a64458e36036e64b4afa9659045d82bfa8/pyzmq-26.2.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:794a4562dcb374f7dbbfb3f51d28fb40123b5a2abadee7b4091f93054909add5", size = 653242 }, + { url = "https://files.pythonhosted.org/packages/50/2d/242e7e6ef6c8c19e6cb52d095834508cd581ffb925699fd3c640cdc758f1/pyzmq-26.2.0-cp313-cp313t-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aee22939bb6075e7afededabad1a56a905da0b3c4e3e0c45e75810ebe3a52672", size = 888404 }, + { url = "https://files.pythonhosted.org/packages/ac/11/7270566e1f31e4ea73c81ec821a4b1688fd551009a3d2bab11ec66cb1e8f/pyzmq-26.2.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2ae90ff9dad33a1cfe947d2c40cb9cb5e600d759ac4f0fd22616ce6540f72797", size = 845858 }, + { url = "https://files.pythonhosted.org/packages/91/d5/72b38fbc69867795c8711bdd735312f9fef1e3d9204e2f63ab57085434b9/pyzmq-26.2.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:43a47408ac52647dfabbc66a25b05b6a61700b5165807e3fbd40063fcaf46386", size = 847375 }, + { url = "https://files.pythonhosted.org/packages/dd/9a/10ed3c7f72b4c24e719c59359fbadd1a27556a28b36cdf1cd9e4fb7845d5/pyzmq-26.2.0-cp313-cp313t-musllinux_1_1_aarch64.whl", hash = "sha256:25bf2374a2a8433633c65ccb9553350d5e17e60c8eb4de4d92cc6bd60f01d306", size = 1183489 }, + { url = "https://files.pythonhosted.org/packages/72/2d/8660892543fabf1fe41861efa222455811adac9f3c0818d6c3170a1153e3/pyzmq-26.2.0-cp313-cp313t-musllinux_1_1_i686.whl", hash = "sha256:007137c9ac9ad5ea21e6ad97d3489af654381324d5d3ba614c323f60dab8fae6", size = 1492932 }, + { url = "https://files.pythonhosted.org/packages/7b/d6/32fd69744afb53995619bc5effa2a405ae0d343cd3e747d0fbc43fe894ee/pyzmq-26.2.0-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:470d4a4f6d48fb34e92d768b4e8a5cc3780db0d69107abf1cd7ff734b9766eb0", size = 1392485 }, + { url = "https://files.pythonhosted.org/packages/ed/69/0529b59ac667ea8bfe8796ac71796b688fbb42ff78e06525dabfed3bc7ae/pyzmq-26.2.0-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:d049df610ac811dcffdc147153b414147428567fbbc8be43bb8885f04db39d98", size = 908009 }, + { url = "https://files.pythonhosted.org/packages/6e/bd/3ff3e1172f12f55769793a3a334e956ec2886805ebfb2f64756b6b5c6a1a/pyzmq-26.2.0-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:05590cdbc6b902101d0e65d6a4780af14dc22914cc6ab995d99b85af45362cc9", size = 862078 }, + { url = "https://files.pythonhosted.org/packages/c3/ec/ab13585c3a1f48e2874253844c47b194d56eb25c94718691349c646f336f/pyzmq-26.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c811cfcd6a9bf680236c40c6f617187515269ab2912f3d7e8c0174898e2519db", size = 673756 }, + { url = "https://files.pythonhosted.org/packages/1e/be/febcd4b04dd50ee6d514dfbc33a3d5d9cb38ec9516e02bbfc929baa0f141/pyzmq-26.2.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:6835dd60355593de10350394242b5757fbbd88b25287314316f266e24c61d073", size = 1203684 }, + { url = "https://files.pythonhosted.org/packages/16/28/304150e71afd2df3b82f52f66c0d8ab9ac6fe1f1ffdf92bad4c8cc91d557/pyzmq-26.2.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:bc6bee759a6bddea5db78d7dcd609397449cb2d2d6587f48f3ca613b19410cfc", size = 1515864 }, + { url = "https://files.pythonhosted.org/packages/18/89/8d48d8cd505c12a1f5edee597cc32ffcedc65fd8d2603aebaaedc38a7041/pyzmq-26.2.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c530e1eecd036ecc83c3407f77bb86feb79916d4a33d11394b8234f3bd35b940", size = 1415383 }, + { url = "https://files.pythonhosted.org/packages/77/8f/6ce54f8979a01656e894946db6299e2273fcee21c8e5fa57c6295ef11f57/pyzmq-26.2.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8b435f2753621cd36e7c1762156815e21c985c72b19135dac43a7f4f31d28dd1", size = 565701 }, + { url = "https://files.pythonhosted.org/packages/ee/1c/bf8cd66730a866b16db8483286078892b7f6536f8c389fb46e4beba0a970/pyzmq-26.2.0-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:160c7e0a5eb178011e72892f99f918c04a131f36056d10d9c1afb223fc952c2d", size = 794312 }, + { url = "https://files.pythonhosted.org/packages/71/43/91fa4ff25bbfdc914ab6bafa0f03241d69370ef31a761d16bb859f346582/pyzmq-26.2.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c4a71d5d6e7b28a47a394c0471b7e77a0661e2d651e7ae91e0cab0a587859ca", size = 752775 }, + { url = "https://files.pythonhosted.org/packages/da/f2/8054574d77c269c31d055d4daf3d8407adf61ea384a50c8d14b158551d09/pyzmq-26.2.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:35cffef589bcdc587d06f9149f8d5e9e8859920a071df5a2671de2213bef592a", size = 565698 }, + { url = "https://files.pythonhosted.org/packages/77/21/c3ad93236d1d60eea10b67528f55e7db115a9d32e2bf163fcf601f85e9cc/pyzmq-26.2.0-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18c8dc3b7468d8b4bdf60ce9d7141897da103c7a4690157b32b60acb45e333e6", size = 794307 }, + { url = "https://files.pythonhosted.org/packages/6a/49/e95b491724500fcb760178ce8db39b923429e328e57bcf9162e32c2c187c/pyzmq-26.2.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7133d0a1677aec369d67dd78520d3fa96dd7f3dcec99d66c1762870e5ea1a50a", size = 752769 }, + { url = "https://files.pythonhosted.org/packages/9b/a9/50c9c06762b30792f71aaad8d1886748d39c4bffedc1171fbc6ad2b92d67/pyzmq-26.2.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:6a96179a24b14fa6428cbfc08641c779a53f8fcec43644030328f44034c7f1f4", size = 751338 }, +] + [[package]] name = "regex" version = "2024.9.11" @@ -1172,6 +1975,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8d/02/1165905f14962174e6569076bcc3315809ae1291ed14de6448cc151eedfd/scipy-1.13.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a014c2b3697bde71724244f63de2476925596c24285c7a637364761f8710891c", size = 38845637 }, ] +[[package]] +name = "sentencepiece" +version = "0.2.0" +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } +wheels = [ + { url = "https://download.pytorch.org/whl/nightly/sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl" }, + { url = "https://download.pytorch.org/whl/nightly/sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/sentencepiece-0.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/sentencepiece-0.2.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl" }, + { url = "https://download.pytorch.org/whl/nightly/sentencepiece-0.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/sentencepiece-0.2.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/sentencepiece-0.2.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl" }, + { url = "https://download.pytorch.org/whl/nightly/sentencepiece-0.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/sentencepiece-0.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/sentencepiece-0.2.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl" }, + { url = "https://download.pytorch.org/whl/nightly/sentencepiece-0.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, +] + [[package]] name = "setuptools" version = "75.1.0" @@ -1189,6 +2011,24 @@ wheels = [ { url = "https://download.pytorch.org/whl/test/six-1.16.0-py2.py3-none-any.whl" }, ] +[[package]] +name = "six" +version = "1.17.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/94/e7/b2c673351809dca68a0e064b6af791aa332cf192da575fd474ed7d6f16a2/six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81", size = 34031 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050 }, +] + +[[package]] +name = "sniffio" +version = "1.3.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a2/87/a6771e1546d97e7e041b6ae58d80074f81b7d5121207425c964ddf5cfdbd/sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc", size = 20372 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235 }, +] + [[package]] name = "stack-data" version = "0.6.3" @@ -1203,6 +2043,28 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695", size = 24521 }, ] +[[package]] +name = "starlette" +version = "0.41.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "typing-extensions", marker = "(python_full_version < '3.10' and sys_platform == 'linux') or (python_full_version < '3.10' and sys_platform == 'windows')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1a/4c/9b5764bd22eec91c4039ef4c55334e9187085da2d8a2df7bd570869aae18/starlette-0.41.3.tar.gz", hash = "sha256:0e4ab3d16522a255be6b28260b938eae2482f98ce5cc934cb08dce8dc3ba5835", size = 2574159 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/96/00/2b325970b3060c7cecebab6d295afe763365822b1306a12eeab198f74323/starlette-0.41.3-py3-none-any.whl", hash = "sha256:44cedb2b7c77a9de33a8b74b2b90e9f50d11fcf25d8270ea525ad71a25374ff7", size = 73225 }, +] + +[[package]] +name = "strenum" +version = "0.4.15" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/85/ad/430fb60d90e1d112a62ff57bdd1f286ec73a2a0331272febfddd21f330e1/StrEnum-0.4.15.tar.gz", hash = "sha256:878fb5ab705442070e4dd1929bb5e2249511c0bcf2b0eeacf3bcd80875c82eff", size = 23384 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/81/69/297302c5f5f59c862faa31e6cb9a4cd74721cd1e052b38e464c5b402df8b/StrEnum-0.4.15-py3-none-any.whl", hash = "sha256:a30cda4af7cc6b5bf52c8055bc4bf4b2b6b14a93b574626da33df53cf7740659", size = 8851 }, +] + [[package]] name = "sympy" version = "1.13.1" @@ -1214,6 +2076,15 @@ wheels = [ { url = "https://download.pytorch.org/whl/test/sympy-1.13.1-py3-none-any.whl" }, ] +[[package]] +name = "tensorrt" +version = "10.7.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "tensorrt-cu12", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/76/a6/791f1e15e708be8c9d958236c1659d1d27fe8a31fc84f011f9920c446d7a/tensorrt-10.7.0.tar.gz", hash = "sha256:2c74ee72c450ef055e894e04f27510094f81e575a284af9d4b4b06a4369d0e2d", size = 16401 } + [[package]] name = "tensorrt-cu12" version = "10.7.0" @@ -1244,6 +2115,54 @@ dependencies = [ ] sdist = { url = "https://files.pythonhosted.org/packages/e7/72/69dd1a6ae202c826f61dcd5bc3019449ed1c961d0b60804311d2493403c8/tensorrt_cu12_libs-10.7.0.tar.gz", hash = "sha256:45f6419adf59ad6532d44f945f1efe79a2975089bf7f92a4374945df40951c02", size = 699 } +[[package]] +name = "tensorrt-llm" +version = "0.16.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "accelerate", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "aenum", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "build", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "click", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "click-option-group", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "colored", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "cuda-python", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "diffusers", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "evaluate", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "fastapi", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "h5py", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "httpx", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "lark", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "mpi4py", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "mpmath", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "numpy", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "nvidia-modelopt", extra = ["torch"], marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "nvidia-nccl-cu12", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "onnx", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "onnx-graphsurgeon", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "openai", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "optimum", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "ordered-set", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "pandas", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "pillow", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "polygraphy", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "psutil", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "pulp", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "pydantic", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "pynvml", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "pyzmq", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "sentencepiece", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "setuptools", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "strenum", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "tensorrt", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "torch", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "torchvision", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "transformers", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "uvicorn", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "wheel", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/de/66/91b2582f56f370b54be501bd9fb8e4b90bdc16fe400173b6aa04a2ab7c70/tensorrt_llm-0.16.0.tar.gz", hash = "sha256:f0337f979b85f2579ad6fd3d3062cb96b1ccc3876e635ca8070617572ec12008", size = 1240 } + [[package]] name = "tokenizers" version = "0.19.1" @@ -1357,6 +2276,9 @@ dependencies = [ ] [package.optional-dependencies] +distributed = [ + { name = "tensorrt-llm", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, +] jupyter = [ { name = "rich", extra = ["jupyter"], marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, ] @@ -1389,7 +2311,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "numpy" }, - { name = "nvidia-modelopt", extras = ["deploy", "hf", "torch"], marker = "extra == 'quantization'", specifier = "~=0.17.0" }, + { name = "nvidia-modelopt", extras = ["deploy", "hf", "torch"], marker = "extra == 'quantization'", specifier = ">=0.17.0" }, { name = "packaging", specifier = ">=23" }, { name = "rich", marker = "extra == 'monitoring-tools'", specifier = ">=13.7.1" }, { name = "rich", extras = ["jupyter"], marker = "extra == 'jupyter'", specifier = ">=13.7.1" }, @@ -1537,6 +2459,20 @@ wheels = [ { url = "https://download.pytorch.org/whl/test/urllib3-2.2.3-py3-none-any.whl" }, ] +[[package]] +name = "uvicorn" +version = "0.34.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "h11", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "typing-extensions", marker = "(python_full_version < '3.11' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform == 'windows')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4b/4d/938bd85e5bf2edeec766267a5015ad969730bb91e31b44021dfe8b22df6c/uvicorn-0.34.0.tar.gz", hash = "sha256:404051050cd7e905de2c9a7e61790943440b3416f49cb409f965d9dcd0fa73e9", size = 76568 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/61/14/33a3a1352cfa71812a3a21e8c9bfb83f60b0011f5e36f2b1399d51928209/uvicorn-0.34.0-py3-none-any.whl", hash = "sha256:023dc038422502fa28a09c7a30bf2b6991512da7dcdb8fd35fe57cfc154126f4", size = 62315 }, +] + [[package]] name = "virtualenv" version = "20.26.5" @@ -1560,6 +2496,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fd/84/fd2ba7aafacbad3c4201d395674fc6348826569da3c0937e75505ead3528/wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859", size = 34166 }, ] +[[package]] +name = "wheel" +version = "0.45.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8a/98/2d9906746cdc6a6ef809ae6338005b3f21bb568bea3165cfc6a243fdc25c/wheel-0.45.1.tar.gz", hash = "sha256:661e1abd9198507b1409a20c02106d9670b2576e916d58f520316666abca6729", size = 107545 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0b/2c/87f3254fd8ffd29e4c02732eee68a83a1d3c346ae39bc6822dcbcb697f2b/wheel-0.45.1-py3-none-any.whl", hash = "sha256:708e7481cc80179af0e556bbf0cc00b8444c7321e2700b8d8580231d13017248", size = 72494 }, +] + [[package]] name = "widgetsnbextension" version = "4.0.13" @@ -1569,6 +2514,107 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/21/02/88b65cc394961a60c43c70517066b6b679738caf78506a5da7b88ffcb643/widgetsnbextension-4.0.13-py3-none-any.whl", hash = "sha256:74b2692e8500525cc38c2b877236ba51d34541e6385eeed5aec15a70f88a6c71", size = 2335872 }, ] +[[package]] +name = "xxhash" +version = "3.5.0" +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } +wheels = [ + { url = "https://download.pytorch.org/whl/nightly/xxhash-3.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/xxhash-3.5.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" }, + { url = "https://download.pytorch.org/whl/nightly/xxhash-3.5.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/xxhash-3.5.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl" }, + { url = "https://download.pytorch.org/whl/nightly/xxhash-3.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/xxhash-3.5.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" }, + { url = "https://download.pytorch.org/whl/nightly/xxhash-3.5.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/xxhash-3.5.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl" }, + { url = "https://download.pytorch.org/whl/nightly/xxhash-3.5.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/xxhash-3.5.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" }, + { url = "https://download.pytorch.org/whl/nightly/xxhash-3.5.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/xxhash-3.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/xxhash-3.5.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl" }, + { url = "https://download.pytorch.org/whl/nightly/xxhash-3.5.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/xxhash-3.5.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" }, + { url = "https://download.pytorch.org/whl/nightly/xxhash-3.5.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/xxhash-3.5.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/xxhash-3.5.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl" }, + { url = "https://download.pytorch.org/whl/nightly/xxhash-3.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/xxhash-3.5.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" }, + { url = "https://download.pytorch.org/whl/nightly/xxhash-3.5.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/xxhash-3.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/xxhash-3.5.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl" }, +] + +[[package]] +name = "yarl" +version = "1.18.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "idna", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "multidict", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, + { name = "propcache", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b7/9d/4b94a8e6d2b51b599516a5cb88e5bc99b4d8d4583e468057eaa29d5f0918/yarl-1.18.3.tar.gz", hash = "sha256:ac1801c45cbf77b6c99242eeff4fffb5e4e73a800b5c4ad4fc0be5def634d2e1", size = 181062 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/76/f9/d616a5c2daae281171de10fba41e1c0e2d8207166fc3547252f7d469b4e1/yarl-1.18.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c654d5207c78e0bd6d749f6dae1dcbbfde3403ad3a4b11f3c5544d9906969dde", size = 315349 }, + { url = "https://files.pythonhosted.org/packages/bb/b4/3ea5e7b6f08f698b3769a06054783e434f6d59857181b5c4e145de83f59b/yarl-1.18.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5094d9206c64181d0f6e76ebd8fb2f8fe274950a63890ee9e0ebfd58bf9d787b", size = 330494 }, + { url = "https://files.pythonhosted.org/packages/55/f1/e0fc810554877b1b67420568afff51b967baed5b53bcc983ab164eebf9c9/yarl-1.18.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:35098b24e0327fc4ebdc8ffe336cee0a87a700c24ffed13161af80124b7dc8e5", size = 326927 }, + { url = "https://files.pythonhosted.org/packages/a9/42/b1753949b327b36f210899f2dd0a0947c0c74e42a32de3f8eb5c7d93edca/yarl-1.18.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3236da9272872443f81fedc389bace88408f64f89f75d1bdb2256069a8730ccc", size = 319703 }, + { url = "https://files.pythonhosted.org/packages/f0/6d/e87c62dc9635daefb064b56f5c97df55a2e9cc947a2b3afd4fd2f3b841c7/yarl-1.18.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2c08cc9b16f4f4bc522771d96734c7901e7ebef70c6c5c35dd0f10845270bcd", size = 310246 }, + { url = "https://files.pythonhosted.org/packages/e3/ef/e2e8d1785cdcbd986f7622d7f0098205f3644546da7919c24b95790ec65a/yarl-1.18.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:80316a8bd5109320d38eef8833ccf5f89608c9107d02d2a7f985f98ed6876990", size = 319730 }, + { url = "https://files.pythonhosted.org/packages/fc/15/8723e22345bc160dfde68c4b3ae8b236e868f9963c74015f1bc8a614101c/yarl-1.18.3-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:c1e1cc06da1491e6734f0ea1e6294ce00792193c463350626571c287c9a704db", size = 321681 }, + { url = "https://files.pythonhosted.org/packages/86/09/bf764e974f1516efa0ae2801494a5951e959f1610dd41edbfc07e5e0f978/yarl-1.18.3-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:fea09ca13323376a2fdfb353a5fa2e59f90cd18d7ca4eaa1fd31f0a8b4f91e62", size = 324812 }, + { url = "https://files.pythonhosted.org/packages/f6/4c/20a0187e3b903c97d857cf0272d687c1b08b03438968ae8ffc50fe78b0d6/yarl-1.18.3-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:e3b9fd71836999aad54084906f8663dffcd2a7fb5cdafd6c37713b2e72be1760", size = 337011 }, + { url = "https://files.pythonhosted.org/packages/c9/71/6244599a6e1cc4c9f73254a627234e0dad3883ece40cc33dce6265977461/yarl-1.18.3-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:757e81cae69244257d125ff31663249b3013b5dc0a8520d73694aed497fb195b", size = 338132 }, + { url = "https://files.pythonhosted.org/packages/af/f5/e0c3efaf74566c4b4a41cb76d27097df424052a064216beccae8d303c90f/yarl-1.18.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b1771de9944d875f1b98a745bc547e684b863abf8f8287da8466cf470ef52690", size = 331849 }, + { url = "https://files.pythonhosted.org/packages/ed/fe/88b690b30f3f59275fb674f5f93ddd4a3ae796c2b62e5bb9ece8a4914b83/yarl-1.18.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d980e0325b6eddc81331d3f4551e2a333999fb176fd153e075c6d1c2530aa8a8", size = 340649 }, + { url = "https://files.pythonhosted.org/packages/07/eb/3b65499b568e01f36e847cebdc8d7ccb51fff716dbda1ae83c3cbb8ca1c9/yarl-1.18.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b643562c12680b01e17239be267bc306bbc6aac1f34f6444d1bded0c5ce438ca", size = 356623 }, + { url = "https://files.pythonhosted.org/packages/33/46/f559dc184280b745fc76ec6b1954de2c55595f0ec0a7614238b9ebf69618/yarl-1.18.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c017a3b6df3a1bd45b9fa49a0f54005e53fbcad16633870104b66fa1a30a29d8", size = 354007 }, + { url = "https://files.pythonhosted.org/packages/af/ba/1865d85212351ad160f19fb99808acf23aab9a0f8ff31c8c9f1b4d671fc9/yarl-1.18.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75674776d96d7b851b6498f17824ba17849d790a44d282929c42dbb77d4f17ae", size = 344145 }, + { url = "https://files.pythonhosted.org/packages/94/cb/5c3e975d77755d7b3d5193e92056b19d83752ea2da7ab394e22260a7b824/yarl-1.18.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ccaa3a4b521b780a7e771cc336a2dba389a0861592bbce09a476190bb0c8b4b3", size = 336133 }, + { url = "https://files.pythonhosted.org/packages/19/89/b77d3fd249ab52a5c40859815765d35c91425b6bb82e7427ab2f78f5ff55/yarl-1.18.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:2d06d3005e668744e11ed80812e61efd77d70bb7f03e33c1598c301eea20efbb", size = 347967 }, + { url = "https://files.pythonhosted.org/packages/35/bd/f6b7630ba2cc06c319c3235634c582a6ab014d52311e7d7c22f9518189b5/yarl-1.18.3-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:9d41beda9dc97ca9ab0b9888cb71f7539124bc05df02c0cff6e5acc5a19dcc6e", size = 346397 }, + { url = "https://files.pythonhosted.org/packages/18/1a/0b4e367d5a72d1f095318344848e93ea70da728118221f84f1bf6c1e39e7/yarl-1.18.3-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:ba23302c0c61a9999784e73809427c9dbedd79f66a13d84ad1b1943802eaaf59", size = 350206 }, + { url = "https://files.pythonhosted.org/packages/b5/cf/320fff4367341fb77809a2d8d7fe75b5d323a8e1b35710aafe41fdbf327b/yarl-1.18.3-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:6748dbf9bfa5ba1afcc7556b71cda0d7ce5f24768043a02a58846e4a443d808d", size = 362089 }, + { url = "https://files.pythonhosted.org/packages/57/cf/aadba261d8b920253204085268bad5e8cdd86b50162fcb1b10c10834885a/yarl-1.18.3-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:0b0cad37311123211dc91eadcb322ef4d4a66008d3e1bdc404808992260e1a0e", size = 366267 }, + { url = "https://files.pythonhosted.org/packages/54/58/fb4cadd81acdee6dafe14abeb258f876e4dd410518099ae9a35c88d8097c/yarl-1.18.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0fb2171a4486bb075316ee754c6d8382ea6eb8b399d4ec62fde2b591f879778a", size = 359141 }, + { url = "https://files.pythonhosted.org/packages/6b/32/927b2d67a412c31199e83fefdce6e645247b4fb164aa1ecb35a0f9eb2058/yarl-1.18.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:436c4fc0a4d66b2badc6c5fc5ef4e47bb10e4fd9bf0c79524ac719a01f3607c2", size = 332368 }, + { url = "https://files.pythonhosted.org/packages/19/e5/859fca07169d6eceeaa4fde1997c91d8abde4e9a7c018e371640c2da2b71/yarl-1.18.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e35ef8683211db69ffe129a25d5634319a677570ab6b2eba4afa860f54eeaf75", size = 342314 }, + { url = "https://files.pythonhosted.org/packages/08/75/76b63ccd91c9e03ab213ef27ae6add2e3400e77e5cdddf8ed2dbc36e3f21/yarl-1.18.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:84b2deecba4a3f1a398df819151eb72d29bfeb3b69abb145a00ddc8d30094512", size = 341987 }, + { url = "https://files.pythonhosted.org/packages/1a/e1/a097d5755d3ea8479a42856f51d97eeff7a3a7160593332d98f2709b3580/yarl-1.18.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:00e5a1fea0fd4f5bfa7440a47eff01d9822a65b4488f7cff83155a0f31a2ecba", size = 336914 }, + { url = "https://files.pythonhosted.org/packages/0b/42/e1b4d0e396b7987feceebe565286c27bc085bf07d61a59508cdaf2d45e63/yarl-1.18.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d0e883008013c0e4aef84dcfe2a0b172c4d23c2669412cf5b3371003941f72bb", size = 325765 }, + { url = "https://files.pythonhosted.org/packages/7e/18/03a5834ccc9177f97ca1bbb245b93c13e58e8225276f01eedc4cc98ab820/yarl-1.18.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:5a3f356548e34a70b0172d8890006c37be92995f62d95a07b4a42e90fba54272", size = 344444 }, + { url = "https://files.pythonhosted.org/packages/c8/03/a713633bdde0640b0472aa197b5b86e90fbc4c5bc05b727b714cd8a40e6d/yarl-1.18.3-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:ccd17349166b1bee6e529b4add61727d3f55edb7babbe4069b5764c9587a8cc6", size = 340760 }, + { url = "https://files.pythonhosted.org/packages/eb/99/f6567e3f3bbad8fd101886ea0276c68ecb86a2b58be0f64077396cd4b95e/yarl-1.18.3-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:b958ddd075ddba5b09bb0be8a6d9906d2ce933aee81100db289badbeb966f54e", size = 346484 }, + { url = "https://files.pythonhosted.org/packages/8e/a9/84717c896b2fc6cb15bd4eecd64e34a2f0a9fd6669e69170c73a8b46795a/yarl-1.18.3-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:c7d79f7d9aabd6011004e33b22bc13056a3e3fb54794d138af57f5ee9d9032cb", size = 359864 }, + { url = "https://files.pythonhosted.org/packages/1e/2e/d0f5f1bef7ee93ed17e739ec8dbcb47794af891f7d165fa6014517b48169/yarl-1.18.3-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:4891ed92157e5430874dad17b15eb1fda57627710756c27422200c52d8a4e393", size = 364537 }, + { url = "https://files.pythonhosted.org/packages/97/8a/568d07c5d4964da5b02621a517532adb8ec5ba181ad1687191fffeda0ab6/yarl-1.18.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ce1af883b94304f493698b00d0f006d56aea98aeb49d75ec7d98cd4a777e9285", size = 357861 }, + { url = "https://files.pythonhosted.org/packages/56/4e/d2563d8323a7e9a414b5b25341b3942af5902a2263d36d20fb17c40411e2/yarl-1.18.3-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88a19f62ff30117e706ebc9090b8ecc79aeb77d0b1f5ec10d2d27a12bc9f66d0", size = 333587 }, + { url = "https://files.pythonhosted.org/packages/25/c9/cfec0bc0cac8d054be223e9f2c7909d3e8442a856af9dbce7e3442a8ec8d/yarl-1.18.3-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e17c9361d46a4d5addf777c6dd5eab0715a7684c2f11b88c67ac37edfba6c482", size = 344386 }, + { url = "https://files.pythonhosted.org/packages/ab/5d/4c532190113b25f1364d25f4c319322e86232d69175b91f27e3ebc2caf9a/yarl-1.18.3-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1a74a13a4c857a84a845505fd2d68e54826a2cd01935a96efb1e9d86c728e186", size = 345421 }, + { url = "https://files.pythonhosted.org/packages/23/d1/6cdd1632da013aa6ba18cee4d750d953104a5e7aac44e249d9410a972bf5/yarl-1.18.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41f7ce59d6ee7741af71d82020346af364949314ed3d87553763a2df1829cc58", size = 339384 }, + { url = "https://files.pythonhosted.org/packages/9a/c4/6b3c39bec352e441bd30f432cda6ba51681ab19bb8abe023f0d19777aad1/yarl-1.18.3-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f52a265001d830bc425f82ca9eabda94a64a4d753b07d623a9f2863fde532b53", size = 326689 }, + { url = "https://files.pythonhosted.org/packages/23/30/07fb088f2eefdc0aa4fc1af4e3ca4eb1a3aadd1ce7d866d74c0f124e6a85/yarl-1.18.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:82123d0c954dc58db301f5021a01854a85bf1f3bb7d12ae0c01afc414a882ca2", size = 345453 }, + { url = "https://files.pythonhosted.org/packages/63/09/d54befb48f9cd8eec43797f624ec37783a0266855f4930a91e3d5c7717f8/yarl-1.18.3-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:2ec9bbba33b2d00999af4631a3397d1fd78290c48e2a3e52d8dd72db3a067ac8", size = 341872 }, + { url = "https://files.pythonhosted.org/packages/91/26/fd0ef9bf29dd906a84b59f0cd1281e65b0c3e08c6aa94b57f7d11f593518/yarl-1.18.3-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:fbd6748e8ab9b41171bb95c6142faf068f5ef1511935a0aa07025438dd9a9bc1", size = 347497 }, + { url = "https://files.pythonhosted.org/packages/d9/b5/14ac7a256d0511b2ac168d50d4b7d744aea1c1aa20c79f620d1059aab8b2/yarl-1.18.3-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:877d209b6aebeb5b16c42cbb377f5f94d9e556626b1bfff66d7b0d115be88d0a", size = 359981 }, + { url = "https://files.pythonhosted.org/packages/ca/b3/d493221ad5cbd18bc07e642894030437e405e1413c4236dd5db6e46bcec9/yarl-1.18.3-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:b464c4ab4bfcb41e3bfd3f1c26600d038376c2de3297760dfe064d2cb7ea8e10", size = 366229 }, + { url = "https://files.pythonhosted.org/packages/04/56/6a3e2a5d9152c56c346df9b8fb8edd2c8888b1e03f96324d457e5cf06d34/yarl-1.18.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:8d39d351e7faf01483cc7ff7c0213c412e38e5a340238826be7e0e4da450fdc8", size = 360383 }, + { url = "https://files.pythonhosted.org/packages/0f/4f/438c9fd668954779e48f08c0688ee25e0673380a21bb1e8ccc56de5b55d7/yarl-1.18.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09c7907c8548bcd6ab860e5f513e727c53b4a714f459b084f6580b49fa1b9cee", size = 317327 }, + { url = "https://files.pythonhosted.org/packages/bd/79/a78066f06179b4ed4581186c136c12fcfb928c475cbeb23743e71a991935/yarl-1.18.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b4f6450109834af88cb4cc5ecddfc5380ebb9c228695afc11915a0bf82116789", size = 336999 }, + { url = "https://files.pythonhosted.org/packages/55/02/527963cf65f34a06aed1e766ff9a3b3e7d0eaa1c90736b2948a62e528e1d/yarl-1.18.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a9ca04806f3be0ac6d558fffc2fdf8fcef767e0489d2684a21912cc4ed0cd1b8", size = 331693 }, + { url = "https://files.pythonhosted.org/packages/a2/2a/167447ae39252ba624b98b8c13c0ba35994d40d9110e8a724c83dbbb5822/yarl-1.18.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:77a6e85b90a7641d2e07184df5557132a337f136250caafc9ccaa4a2a998ca2c", size = 321473 }, + { url = "https://files.pythonhosted.org/packages/55/03/07955fabb20082373be311c91fd78abe458bc7ff9069d34385e8bddad20e/yarl-1.18.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6333c5a377c8e2f5fae35e7b8f145c617b02c939d04110c76f29ee3676b5f9a5", size = 313571 }, + { url = "https://files.pythonhosted.org/packages/95/e2/67c8d3ec58a8cd8ddb1d63bd06eb7e7b91c9f148707a3eeb5a7ed87df0ef/yarl-1.18.3-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:0b3c92fa08759dbf12b3a59579a4096ba9af8dd344d9a813fc7f5070d86bbab1", size = 325004 }, + { url = "https://files.pythonhosted.org/packages/06/43/51ceb3e427368fe6ccd9eccd162be227fd082523e02bad1fd3063daf68da/yarl-1.18.3-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:4ac515b860c36becb81bb84b667466885096b5fc85596948548b667da3bf9f24", size = 322677 }, + { url = "https://files.pythonhosted.org/packages/e4/0e/7ef286bfb23267739a703f7b967a858e2128c10bea898de8fa027e962521/yarl-1.18.3-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:045b8482ce9483ada4f3f23b3774f4e1bf4f23a2d5c912ed5170f68efb053318", size = 332806 }, + { url = "https://files.pythonhosted.org/packages/c8/94/2d1f060f4bfa47c8bd0bcb652bfe71fba881564bcac06ebb6d8ced9ac3bc/yarl-1.18.3-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:a4bb030cf46a434ec0225bddbebd4b89e6471814ca851abb8696170adb163985", size = 339919 }, + { url = "https://files.pythonhosted.org/packages/8e/8d/73b5f9a6ab69acddf1ca1d5e7bc92f50b69124512e6c26b36844531d7f23/yarl-1.18.3-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:54d6921f07555713b9300bee9c50fb46e57e2e639027089b1d795ecd9f7fa910", size = 340960 }, + { url = "https://files.pythonhosted.org/packages/41/13/ce6bc32be4476b60f4f8694831f49590884b2c975afcffc8d533bf2be7ec/yarl-1.18.3-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:1d407181cfa6e70077df3377938c08012d18893f9f20e92f7d2f314a437c30b1", size = 336592 }, + { url = "https://files.pythonhosted.org/packages/f5/4b/a06e0ec3d155924f77835ed2d167ebd3b211a7b0853da1cf8d8414d784ef/yarl-1.18.3-py3-none-any.whl", hash = "sha256:b57f4f58099328dfb26c6a771d09fb20dbbae81d20cfb66141251ea063bd101b", size = 45109 }, +] + [[package]] name = "zipp" version = "3.20.2"