Skip to content

Commit

Permalink
AsyncConcatenateIterator
Browse files Browse the repository at this point in the history
Signed-off-by: technillogue <[email protected]>
  • Loading branch information
technillogue committed May 8, 2024
1 parent bb01c85 commit 12b0abe
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 10 deletions.
3 changes: 2 additions & 1 deletion python/cog/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pydantic import BaseModel

from .predictor import BasePredictor
from .types import ConcatenateIterator, File, Input, Path
from .types import AsyncConcatenateIterator, ConcatenateIterator, File, Input, Path

try:
from ._version import __version__
Expand All @@ -14,6 +14,7 @@
"BaseModel",
"BasePredictor",
"ConcatenateIterator",
"AsyncConcatenateIterator",
"File",
"Input",
"Path",
Expand Down
33 changes: 24 additions & 9 deletions python/cog/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import urllib.parse
import urllib.request
import urllib.response
from typing import Any, Dict, Iterator, List, Optional, TypeVar, Union
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, TypeVar, Union

import httpx
import requests
Expand Down Expand Up @@ -255,21 +255,20 @@ def __repr__(self) -> str:


Item = TypeVar("Item")
_concatenate_iterator_schema = {
"type": "array",
"items": {"type": "string"},
"x-cog-array-type": "iterator",
"x-cog-array-display": "concatenate",
}


class ConcatenateIterator(Iterator[Item]):
@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
"""Defines what this type should be in openapi.json"""
field_schema.pop("allOf", None)
field_schema.update(
{
"type": "array",
"items": {"type": "string"},
"x-cog-array-type": "iterator",
"x-cog-array-display": "concatenate",
}
)
field_schema.update(_concatenate_iterator_schema)

@classmethod
def __get_validators__(cls) -> Iterator[Any]:
Expand All @@ -280,6 +279,22 @@ def validate(cls, value: Iterator[Any]) -> Iterator[Any]:
return value


class AsyncConcatenateIterator(AsyncIterator[Item]):
@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
"""Defines what this type should be in openapi.json"""
field_schema.pop("allOf", None)
field_schema.update(_concatenate_iterator_schema)

@classmethod
def __get_validators__(cls) -> Iterator[Any]:
yield cls.validate

@classmethod
def validate(cls, value: AsyncIterator[Any]) -> AsyncIterator[Any]:
return value


def _len_bytes(s: str, encoding: str = "utf-8") -> int:
return len(s.encode(encoding))

Expand Down

0 comments on commit 12b0abe

Please sign in to comment.