Skip to content

Commit

Permalink
ready
Browse files Browse the repository at this point in the history
  • Loading branch information
alexandruvesa committed Aug 5, 2024
1 parent 9f84695 commit 6d84fcb
Show file tree
Hide file tree
Showing 25 changed files with 2,636 additions and 400 deletions.
Binary file added .DS_Store
Binary file not shown.
12 changes: 12 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Helper targets for AWS IAM provisioning and model deployment.
# None of these targets produce a file named after themselves, so all of them
# are declared phony to keep make from skipping them if such a file appears.
.PHONY: create-sagemaker-role create-sagemaker-execution-role deploy-inference-endpoint


# Runs the IAM user provisioning script (lives under core/aws/roles/).
create-sagemaker-role:
	poetry run python llm_engineering/core/aws/roles/create_sagemaker_role.py

# Runs the SageMaker execution role provisioning script (core/aws/roles/).
create-sagemaker-execution-role:
	poetry run python llm_engineering/core/aws/roles/create_execution_role.py

# Runs the Hugging Face inference endpoint deployment script.
deploy-inference-endpoint:
	poetry run python llm_engineering/model/deploy/huggingface/run.py

4 changes: 2 additions & 2 deletions llm_engineering/application/dataset/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class DatasetGenerator:
@classmethod
def get_system_prompt(cls) -> Prompt:
return Prompt(
template=cls.system_prompt_template,
template=PromptTemplate.from_template(cls.system_prompt_template),
input_variables={},
content=cls.system_prompt_template,
)
Expand Down Expand Up @@ -93,7 +93,7 @@ def get_prompt(cls, documents: list[CleanedDocument]) -> GenerateDatasetSamplesP
prompt = cls.tokenizer.decode(prompt_tokens)

prompt = GenerateDatasetSamplesPrompt(
template=prompt_template.template,
template=prompt_template,
input_variables=input_variables,
content=prompt,
num_tokens=len(prompt_tokens),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
from typing import Generic, TypeVar
from uuid import UUID

from llm_engineering.domain.chunks import ArticleChunk, Chunk, PostChunk, RepositoryChunk
from llm_engineering.domain.chunks import (
ArticleChunk,
Chunk,
PostChunk,
RepositoryChunk,
)
from llm_engineering.domain.cleaned_documents import (
CleanedArticleDocument,
CleanedDocument,
Expand Down Expand Up @@ -37,19 +42,17 @@ def chunk(self, data_model: CleanedDocumentT) -> list[ChunkT]:


class PostChunkingHandler(ChunkingDataHandler):
@property
def chunk_size(self) -> int:
return 250

@property
def chunk_overlap(self) -> int:
return 25

def chunk(self, data_model: CleanedPostDocument) -> list[PostChunk]:
data_models_list = []

cleaned_content = data_model.content
chunks = chunk_text(cleaned_content, chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)
chunks = chunk_text(cleaned_content)

for chunk in chunks:
chunk_id = hashlib.md5(chunk.encode()).hexdigest()
Expand All @@ -76,7 +79,7 @@ def chunk(self, data_model: CleanedArticleDocument) -> list[ArticleChunk]:
data_models_list = []

cleaned_content = data_model.content
chunks = chunk_text(cleaned_content, chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)
chunks = chunk_text(cleaned_content)

for chunk in chunks:
chunk_id = hashlib.md5(chunk.encode()).hexdigest()
Expand All @@ -99,19 +102,17 @@ def chunk(self, data_model: CleanedArticleDocument) -> list[ArticleChunk]:


class RepositoryChunkingHandler(ChunkingDataHandler):
@property
def chunk_size(self) -> int:
return 750

@property
def chunk_overlap(self) -> int:
return 75

def chunk(self, data_model: CleanedRepositoryDocument) -> list[RepositoryChunk]:
data_models_list = []

cleaned_content = data_model.content
chunks = chunk_text(cleaned_content, chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)
chunks = chunk_text(cleaned_content)

for chunk in chunks:
chunk_id = hashlib.md5(chunk.encode()).hexdigest()
Expand Down
Empty file.
63 changes: 63 additions & 0 deletions llm_engineering/core/aws/roles/create_execution_role.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import json

import boto3

from llm_engineering.settings import settings


def create_sagemaker_execution_role(role_name, region_name="eu-central-1"):
    """Create an IAM role assumable by SageMaker and return its ARN.

    The role is created with a trust policy for ``sagemaker.amazonaws.com``
    and then has four AWS managed policies attached. If IAM reports that a
    role with this name already exists, the existing role's ARN is fetched
    and returned instead (no policies are attached in that case).

    Args:
        role_name: Name of the IAM role to create or look up.
        region_name: Region passed to the boto3 IAM client.

    Returns:
        The ARN string of the created or pre-existing role.
    """
    iam_client = boto3.client(
        "iam",
        region_name=region_name,
        aws_access_key_id=settings.AWS_ACCESS_KEY,
        aws_secret_access_key=settings.AWS_SECRET_KEY,
    )

    # Trust policy: lets the SageMaker service assume this role.
    assume_role_statement = {
        "Effect": "Allow",
        "Principal": {"Service": "sagemaker.amazonaws.com"},
        "Action": "sts:AssumeRole",
    }
    assume_role_policy = json.dumps({"Version": "2012-10-17", "Statement": [assume_role_statement]})

    # AWS managed policies attached to a freshly created role, in this order.
    managed_policy_arns = (
        "arn:aws:iam::aws:policy/AmazonSageMakerFullAccess",
        "arn:aws:iam::aws:policy/AmazonS3FullAccess",
        "arn:aws:iam::aws:policy/CloudWatchLogsFullAccess",
        "arn:aws:iam::aws:policy/AmazonEC2ContainerRegistryFullAccess",
    )

    try:
        created = iam_client.create_role(
            RoleName=role_name,
            AssumeRolePolicyDocument=assume_role_policy,
            Description="Execution role for SageMaker",
        )

        for policy_arn in managed_policy_arns:
            iam_client.attach_role_policy(RoleName=role_name, PolicyArn=policy_arn)

        print(f"Role '{role_name}' created successfully.")
        print(f"Role ARN: {created['Role']['Arn']}")

        return created["Role"]["Arn"]

    except iam_client.exceptions.EntityAlreadyExistsException:
        # Treat "already exists" as success and hand back the existing ARN.
        print(f"Role '{role_name}' already exists. Fetching its ARN...")
        existing = iam_client.get_role(RoleName=role_name)
        return existing["Role"]["Arn"]


if __name__ == "__main__":
    # Script entry point: create (or look up) the role, echo its ARN, and
    # store it as JSON in the current working directory.
    execution_role_arn = create_sagemaker_execution_role("SageMakerExecutionRoleLLM")
    print(execution_role_arn)

    serialized_arn = json.dumps({"RoleArn": execution_role_arn})
    with open("sagemaker_execution_role.json", "w") as arn_file:
        arn_file.write(serialized_arn)

    print("Role ARN saved to 'sagemaker_execution_role.json'")
51 changes: 51 additions & 0 deletions llm_engineering/core/aws/roles/create_sagemaker_role.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import json

import boto3

from llm_engineering.settings import settings


def create_sagemaker_user(username, region_name="eu-central-1"):
    """Create an IAM user with deployment policies and a fresh access key.

    The user is created, five AWS managed policies are attached, and a new
    access key pair is generated for it.

    Args:
        username: Name of the IAM user to create.
        region_name: Region passed to the boto3 IAM client.

    Returns:
        A dict with the keys ``"AccessKeyId"`` and ``"SecretAccessKey"`` of
        the newly created access key. The secret is only available through
        this return value; it is intentionally never written to stdout.

    Raises:
        Any botocore client error raised by the IAM calls propagates to the
        caller unchanged (nothing is caught here).
    """
    # Create IAM client
    iam = boto3.client(
        "iam",
        region_name=region_name,
        aws_access_key_id=settings.AWS_ACCESS_KEY,
        aws_secret_access_key=settings.AWS_SECRET_KEY,
    )

    # Create user
    iam.create_user(UserName=username)

    # Attach necessary policies
    policies = [
        "arn:aws:iam::aws:policy/AmazonSageMakerFullAccess",
        "arn:aws:iam::aws:policy/AWSCloudFormationFullAccess",
        "arn:aws:iam::aws:policy/IAMFullAccess",
        "arn:aws:iam::aws:policy/AmazonEC2ContainerRegistryFullAccess",
        "arn:aws:iam::aws:policy/AmazonS3FullAccess",
    ]

    for policy in policies:
        iam.attach_user_policy(UserName=username, PolicyArn=policy)

    # Create access key
    response = iam.create_access_key(UserName=username)
    access_key = response["AccessKey"]

    print(f"User '{username}' created successfully.")
    print(f"Access Key ID: {access_key['AccessKeyId']}")
    # The secret key is a long-lived credential: echoing it would leave it in
    # terminal scrollback, shell captures and CI logs. Callers get it from the
    # returned dict instead.
    print("Secret Access Key: <hidden; available in the returned credentials>")

    # Return the access key info
    return {"AccessKeyId": access_key["AccessKeyId"], "SecretAccessKey": access_key["SecretAccessKey"]}


if __name__ == "__main__":
    # Local import: only this script entry point needs low-level file creation.
    import os

    new_user = create_sagemaker_user("sagemaker-deployer-2")

    # The file contains a secret access key, so it must not inherit the
    # default umask permissions. Create it as owner read/write only, and
    # tighten the mode explicitly in case the file already existed with
    # broader permissions (the mode passed to os.open only applies on creation).
    credentials_path = "sagemaker_user_credentials.json"
    fd = os.open(credentials_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w") as f:
        os.chmod(credentials_path, 0o600)
        json.dump(new_user, f)

    print("Credentials saved to 'sagemaker_user_credentials.json'")
59 changes: 59 additions & 0 deletions llm_engineering/core/interfaces.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from abc import ABC, abstractmethod
from typing import Any, Dict

from pydantic import BaseModel


class LLMInterface(ABC):
    """Abstract base for objects that answer a prompt using a named model.

    The constructor stores the ``model`` value on the instance; concrete
    subclasses must implement ``get_answer``.
    """

    def __init__(self, model: str):
        self.model = model

    @abstractmethod
    def get_answer(self, prompt: str, *args, **kwargs):
        """Return an answer for ``prompt``; extra arguments are subclass-defined."""
        ...


# Abstract pydantic model for prompt templates: subclasses must implement
# create_template. Documented with comments rather than a class docstring so
# the model's generated schema description is left untouched.
class BasePromptTemplate(ABC, BaseModel):
    @abstractmethod
    def create_template(self, *args) -> str:
        """Build and return the template string from the given arguments."""
        ...


class DeploymentStrategy(ABC):
    """Abstract strategy describing how a model gets deployed to an endpoint."""

    @abstractmethod
    def deploy(self, model, endpoint_name: str, endpoint_config_name: str) -> None:
        """Deploy ``model`` using the given endpoint and endpoint-config names."""
        ...


class Inference(ABC):
    """Abstract base for inference runners.

    Instances start with ``model`` set to ``None``; subclasses implement
    ``set_payload`` and ``inference``.
    """

    def __init__(self):
        self.model = None

    @abstractmethod
    def set_payload(self, inputs, parameters=None):
        """Prepare the request from ``inputs`` and optional ``parameters``."""
        ...

    @abstractmethod
    def inference(self):
        """Run inference; the return value is defined by the subclass."""
        ...


class Summarize(ABC):
    """Abstract base for document summarizers backed by an ``Inference`` object.

    The ``llm`` passed to the constructor is stored on the instance for
    subclasses to use in ``summarize``.
    """

    def __init__(self, llm: Inference):
        self.llm = llm

    @abstractmethod
    def summarize(self, document_structure: dict):
        """Summarize the document described by ``document_structure``."""
        ...


class Task:
    """Base class for a unit of work that maps one dictionary to another.

    This is a plain class rather than an ABC: it can be instantiated, and
    ``execute`` raises until a subclass overrides it.
    """

    def execute(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run the task on ``data``; the base implementation always raises.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError()
47 changes: 47 additions & 0 deletions llm_engineering/core/langchain_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import json

from langchain import PromptTemplate
from langchain.chains import LLMChain
from langchain.llms.sagemaker_endpoint import LLMContentHandler


class ContentHandler(LLMContentHandler):
    """JSON content handler that converts prompts to request bodies and
    extracts the generated summary from the endpoint response.
    """

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
        """Serialize the prompt and generation parameters to UTF-8 JSON.

        The payload has the shape ``{"inputs": prompt, "parameters": model_kwargs}``.
        """
        input_payload = {"inputs": prompt, "parameters": model_kwargs}
        return json.dumps(input_payload).encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        """Extract the text following the first ``SUMMARY:`` marker.

        Args:
            output: Response body. NOTE: despite the ``bytes`` annotation the
                value is consumed through ``.read()``, so it must be a
                stream-like object exposing that method.

        Returns:
            The stripped text after the first ``SUMMARY:`` delimiter found in
            the first element's ``generated_text``, or an empty string when
            the response is not a non-empty list or the delimiter is absent.
        """
        response_json = json.loads(output.read().decode("utf-8"))
        if not (isinstance(response_json, list) and len(response_json) > 0):
            print("Unexpected response format or empty response:", response_json)
            return ""

        full_text = response_json[0].get("generated_text", "")

        # Split on the FIRST delimiter only. An unbounded split followed by
        # taking element 1 would cut the summary short whenever the generated
        # text itself contains "SUMMARY:" again.
        _, delimiter, generated_summary = full_text.partition("SUMMARY:")
        if not delimiter:
            print("Delimiter 'SUMMARY:' not found in the response")
            return ""

        return generated_summary.strip()


class GeneralChain:
    """Namespace for a helper that assembles an ``LLMChain`` from a template string."""

    @staticmethod
    def get_chain(llm, template: str, input_variables=None, verbose=True, output_key=""):
        """Build an ``LLMChain`` for ``llm`` around a prompt made from ``template``.

        ``input_variables`` and ``verbose`` are forwarded to ``PromptTemplate``;
        ``output_key`` and ``verbose`` are forwarded to ``LLMChain``.
        """
        chain_prompt = PromptTemplate(
            input_variables=input_variables,
            template=template,
            verbose=verbose,
        )
        chain_options = {
            "llm": llm,
            "prompt": chain_prompt,
            "output_key": output_key,
            "verbose": verbose,
        }
        return LLMChain(**chain_options)
6 changes: 3 additions & 3 deletions llm_engineering/domain/base/nosql.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def to_mongo(self: T, **kwargs) -> dict:
exclude_unset = kwargs.pop("exclude_unset", False)
by_alias = kwargs.pop("by_alias", True)

parsed = self.model_dump(exclude_unset=exclude_unset, by_alias=by_alias, **kwargs)
parsed = self.dict(exclude_unset=exclude_unset, by_alias=by_alias, **kwargs)

if "_id" not in parsed and "id" in parsed:
parsed["_id"] = str(parsed.pop("id"))
Expand All @@ -55,8 +55,8 @@ def to_mongo(self: T, **kwargs) -> dict:

return parsed

def model_dump(self: T, **kwargs) -> dict:
dict_ = super().model_dump(**kwargs)
def dict(self: T, **kwargs) -> dict:
dict_ = super().dict(**kwargs)

for key, value in dict_.items():
if isinstance(value, uuid.UUID):
Expand Down
6 changes: 3 additions & 3 deletions llm_engineering/domain/base/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def to_point(self: T, **kwargs) -> PointStruct:
exclude_unset = kwargs.pop("exclude_unset", False)
by_alias = kwargs.pop("by_alias", True)

payload = self.model_dump(exclude_unset=exclude_unset, by_alias=by_alias, **kwargs)
payload = self.dict(exclude_unset=exclude_unset, by_alias=by_alias, **kwargs)

_id = str(payload.pop("id"))
vector = payload.pop("embedding", {})
Expand All @@ -57,8 +57,8 @@ def to_point(self: T, **kwargs) -> PointStruct:

return PointStruct(id=_id, vector=vector, payload=payload)

def model_dump(self: T, **kwargs) -> dict:
dict_ = super().model_dump(**kwargs)
def dict(self: T, **kwargs) -> dict:
dict_ = super().dict(**kwargs)

for key, value in dict_.items():
if isinstance(value, UUID):
Expand Down
5 changes: 3 additions & 2 deletions llm_engineering/domain/prompt.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from langchain_core.prompts import PromptTemplate

from llm_engineering.domain.base import VectorBaseDocument
from llm_engineering.domain.cleaned_documents import CleanedDocument
from llm_engineering.domain.types import DataCategory


class Prompt(VectorBaseDocument):
template: str
template: PromptTemplate
input_variables: dict
content: str
num_tokens: int | None = None

class Config:
category = DataCategory.PROMPT
arbitrary_types_allowed = True


class GenerateDatasetSamplesPrompt(Prompt):
Expand Down
Loading

0 comments on commit 6d84fcb

Please sign in to comment.