diff --git a/python/cog/__init__.py b/python/cog/__init__.py index 34b3729550..31677280df 100644 --- a/python/cog/__init__.py +++ b/python/cog/__init__.py @@ -1,7 +1,7 @@ from pydantic import BaseModel from .predictor import BasePredictor -from .types import ConcatenateIterator, File, Input, Path +from .types import AsyncConcatenateIterator, ConcatenateIterator, File, Input, Path try: from ._version import __version__ @@ -14,6 +14,7 @@ "BaseModel", "BasePredictor", "ConcatenateIterator", + "AsyncConcatenateIterator", "File", "Input", "Path", diff --git a/python/cog/types.py b/python/cog/types.py index 2bf54a7437..3c7d45a52b 100644 --- a/python/cog/types.py +++ b/python/cog/types.py @@ -7,7 +7,7 @@ import urllib.parse import urllib.request import urllib.response -from typing import Any, Dict, Iterator, List, Optional, TypeVar, Union +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, TypeVar, Union import httpx import requests @@ -255,6 +255,12 @@ def __repr__(self) -> str: Item = TypeVar("Item") +_concatenate_iterator_schema = { + "type": "array", + "items": {"type": "string"}, + "x-cog-array-type": "iterator", + "x-cog-array-display": "concatenate", +} class ConcatenateIterator(Iterator[Item]): @@ -262,14 +268,7 @@ class ConcatenateIterator(Iterator[Item]): def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: """Defines what this type should be in openapi.json""" field_schema.pop("allOf", None) - field_schema.update( - { - "type": "array", - "items": {"type": "string"}, - "x-cog-array-type": "iterator", - "x-cog-array-display": "concatenate", - } - ) + field_schema.update(_concatenate_iterator_schema) @classmethod def __get_validators__(cls) -> Iterator[Any]: @@ -280,6 +279,22 @@ def validate(cls, value: Iterator[Any]) -> Iterator[Any]: return value +class AsyncConcatenateIterator(AsyncIterator[Item]): + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + """Defines what this type should be in openapi.json""" + field_schema.pop("allOf", None) + field_schema.update(_concatenate_iterator_schema) + + @classmethod + def __get_validators__(cls) -> Iterator[Any]: + yield cls.validate + + @classmethod + def validate(cls, value: AsyncIterator[Any]) -> AsyncIterator[Any]: + return value + + def _len_bytes(s: str, encoding: str = "utf-8") -> int: return len(s.encode(encoding))