Skip to content

Commit

Permalink
Add illustration (#22)
Browse files Browse the repository at this point in the history
  • Loading branch information
cdiddy77 authored Sep 7, 2024
1 parent 6387829 commit 7b347f6
Show file tree
Hide file tree
Showing 18 changed files with 488 additions and 148 deletions.
4 changes: 4 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
OPENAI_API_KEY=the OPENAI_API_KEY
WANDB_API_KEY=the WANDB_API_KEY
STABILITYAI_API_KEY=the STABILITYAI_API_KEY
ILLUSTRATION_PATH=path/to/the/illustrations
102 changes: 51 additions & 51 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pandas = "^2.2.2"
wandb = "^0.17.1"
asyncio = "^3.4.3"
tabulate = "^0.9.0"
pymongo = "^4.7.3"

[tool.poetry.group.dev.dependencies]
mkdocs-material = "^9.5.26"
Expand All @@ -31,6 +32,7 @@ optional = true
python-dotenv = "^1.0.1"
fastapi = "^0.111.0"
motor = "^3.4.0"
pymongo = "^4.7.3"

[build-system]
requires = ["poetry-core"]
Expand Down
12 changes: 12 additions & 0 deletions webui/api-server/dtos.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from worldender.agents.gm import GameMaster
from worldender.models.choice import Choice
from worldender.models.illustration import Illustration
from worldender.models.location import Location
from worldender.models.outcome import Outcome
from worldender.models.player import Player
Expand Down Expand Up @@ -44,3 +45,14 @@ class GMResponse(BaseModel):

class GMRequest(BaseModel):
query: str


class NewIllustrationRequest(BaseModel):
prompt: str
negative_prompt: str
aspect_ratio: str


class NewIllustrationResponse(BaseModel):
id: str
result: Illustration
98 changes: 98 additions & 0 deletions webui/api-server/illustrate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import os
import asyncio
import requests
import logging
from worldender.models.illustration import Illustration
from worldender.config import app_config
from typing import Callable
from .storage import store_illustration_error, store_illustration_filepath

logger = logging.getLogger(__name__)


def send_generation_request(
host,
params,
):
headers = {
"Accept": "image/*",
"Authorization": f"Bearer {app_config.stabilityai_api_key}",
}

# Encode parameters
files = {}
image = params.pop("image", None)
mask = params.pop("mask", None)
if image is not None and image != "":
files["image"] = open(image, "rb")
if mask is not None and mask != "":
files["mask"] = open(mask, "rb")
if len(files) == 0:
files["none"] = ""

# Send request
print(f"Sending REST request to {host}...")
response = requests.post(host, headers=headers, files=files, data=params)
if not response.ok:
raise Exception(f"HTTP {response.status_code}: {response.text}")

return response


def gen_illustrate(
id: str,
prompt: str,
negative_prompt="",
aspect_ratio="1:1",
output_format="png",
) -> Illustration:
"""
someone else has created and stored the illustration. This
code asynchronously requests the actual image from stability
once the image is ready, store the pixels in a path on the file system
and update the illustration object in the database with the path
and new state
"""
try:
seed = 0 # @param {type:"integer"}
model = (
"sd3-large-turbo" # @param ["sd3-large", "sd3-large-turbo", "sd3-medium"]
)

host = f"https://api.stability.ai/v2beta/stable-image/generate/sd3"

params = {
"prompt": prompt,
"negative_prompt": negative_prompt if model == "sd3" else "",
"aspect_ratio": aspect_ratio,
"seed": seed,
"output_format": output_format,
"model": model,
"mode": "text-to-image",
}
response = send_generation_request(host, params)

# Decode response
output_image = response.content
finish_reason = response.headers.get("finish-reason")
seed = response.headers.get("seed")

# Check for NSFW classification
if finish_reason == "CONTENT_FILTERED":
raise Warning("Generation failed NSFW classifier")

# Save and display result
file_name = os.path.join(
app_config.illustration_path, f"generated_{seed}.{output_format}"
)
with open(file_name, "wb") as f:
f.write(output_image)
logger.info(f"Saved image {file_name}")

# Update illustration object
# asyncio.run(store_illustration_filepath(id, file_name))
store_illustration_filepath(id, file_name)
# loop.run_until_complete(on_complete(id, file_name))
except Exception as e:
store_illustration_error(id)
raise e
50 changes: 47 additions & 3 deletions webui/api-server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
from worldender.models.game_plan import GamePlan
from worldender.models.question import Question
from worldender.models.question_response import QuestionResponse
from .illustrate import gen_illustrate

load_dotenv()

import openai
import instructor
from fastapi import Body, FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi import Body, FastAPI, Request, BackgroundTasks
from fastapi.responses import FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from worldender.models.location import Location
from worldender.models.world_ender import WorldEnder
Expand All @@ -25,14 +26,16 @@
Choice,
GMResponse,
GMRequest,
NewIllustrationRequest,
NewIllustrationResponse,
NewScenarioRequest,
NewScenarioResponse,
Event,
Player,
Scenario,
World,
)
from .util import gen_scenario_id
from .util import gen_illustration_id, gen_scenario_id
from .error_handling import *
from .storage import *

Expand Down Expand Up @@ -144,6 +147,47 @@ async def post_ask_question(scenario_id: str, data: Question) -> QuestionRespons
return question_response


@app.post("/illustration/new", response_model=NewIllustrationResponse)
async def new_illustration(
req: NewIllustrationRequest, bgtasks: BackgroundTasks
) -> NewIllustrationResponse:
id = gen_illustration_id()
illustration = Illustration(
prompt=req.prompt,
negative_prompt=req.negative_prompt,
aspect_ratio=req.aspect_ratio,
image_path="",
progress="inprogress",
)

await store_illustration(id, illustration)

# specifically dont await this
bgtasks.add_task(
gen_illustrate,
id,
req.prompt,
negative_prompt=req.negative_prompt,
aspect_ratio=req.aspect_ratio,
)
return NewIllustrationResponse(id=id, result=illustration)


@app.get("/illustration/{illustration_id}")
async def get_illustration(illustration_id: str):
illustration = await fetch_illustration(illustration_id)
del illustration.image_path
return illustration


@app.get("/illustration/{illustration_id}/image")
async def get_illustration_image(illustration_id: str):
illustration = await fetch_illustration(illustration_id)
if illustration.progress != "complete":
raise NotFoundException(f"Illustration with id {illustration_id} not ready")
return FileResponse(illustration.image_path, media_type="image/png")


@app.exception_handler(CanonicalException)
async def canonical_exception_handler(request: Request, exc: CanonicalException):
logger.info(f"Caught exception: {exc}")
Expand Down
56 changes: 55 additions & 1 deletion webui/api-server/storage.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,26 @@
import logging
from pymongo import MongoClient
from worldender.models.illustration import Illustration
from .dtos import Scenario
from .error_handling import *
from .util import get_id_from_slug

import motor.motor_asyncio
from bson.objectid import ObjectId

logger = logging.getLogger(__name__)

client = motor.motor_asyncio.AsyncIOMotorClient("mongodb://localhost:27017/")
db = client["worldender"]
scenarios_collection = db["scenarios"]
illustrations_collection = db["illustrations"]

# Insert a Pydantic object


async def fetch_scenario(slug: str) -> Scenario:
id = get_id_from_slug(slug)
print(f"fetching scenario with id {id}")
logger.info(f"fetching scenario with id {id}")
dict = await scenarios_collection.find_one({"_id": ObjectId(id)})
if dict:
return Scenario(**dict)
Expand All @@ -32,3 +38,51 @@ async def store_scenario(id: str, scenario: Scenario):
return True
else:
raise UnknownException(f"Failed to store scenario with id {id}")


async def fetch_illustration(id: str) -> Illustration:
dict = await illustrations_collection.find_one({"_id": ObjectId(id)})
if dict:
return Illustration(**dict)
else:
raise NotFoundException(f"Illustration with id {id} not found")


async def store_illustration(id: str, illustration: Illustration):
logger.info(f"storing illustration with id {id}")
dict = illustration.model_dump()
result = await illustrations_collection.update_one(
{"_id": ObjectId(id)}, {"$set": dict}, upsert=True
)
if result.modified_count == 1 or result.upserted_id is not None:
return True
else:
raise UnknownException(f"Failed to store illustration with id {id}")


sync_client = MongoClient("mongodb://localhost:27017")
sync_db = sync_client.worldender
sync_illustrations_collection = sync_db["illustrations"]


def store_illustration_filepath(id: str, file_name: str):
logger.info(f"storing illustration filepath with id {id}")
result = sync_illustrations_collection.update_one(
{"_id": ObjectId(id)},
{"$set": {"image_path": file_name, "progress": "complete"}},
)
if result.modified_count == 1 or result.upserted_id is not None:
return True
else:
raise UnknownException(f"Failed to store illustration with id {id}")


def store_illustration_error(id: str):
logger.info(f"storing illustration filepath with id {id}")
result = sync_illustrations_collection.update_one(
{"_id": ObjectId(id)}, {"$set": {"progress": "failed"}}
)
if result.modified_count == 1 or result.upserted_id is not None:
return True
else:
raise UnknownException(f"Failed to store illustration with id {id}")
6 changes: 6 additions & 0 deletions webui/api-server/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import random
import re
import unicodedata
import uuid
from .dtos import NewScenarioRequest


Expand All @@ -19,6 +20,11 @@ def gen_scenario_id(req: NewScenarioRequest) -> tuple[str, str]:
return slug, hash_slug.hexdigest()[0:24]


def gen_illustration_id() -> str:
# return guid in string form
return str(uuid.uuid4()).replace("-", "")[0:24]


def get_id_from_slug(slug: str) -> str:
"""Get the id from a slug."""
return hashlib.sha256(slug.encode()).hexdigest()[0:24]
Expand Down
1 change: 1 addition & 0 deletions webui/apps/website/app/scenario/new/page.tsx
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"use client";
import { NewScenario } from "@/components/newScenario/newScenario";

export default function ScenarioNew() {
Expand Down
Loading

0 comments on commit 7b347f6

Please sign in to comment.