Skip to content

Commit

Permalink
Dataset router (#54)
Browse files Browse the repository at this point in the history
* add dataset route

Signed-off-by: cbh778899 <[email protected]>

* prevent reload when dataset loaded

Signed-off-by: cbh778899 <[email protected]>

* implement load dataset route

Signed-off-by: cbh778899 <[email protected]>

* add functions related to api keys

Signed-off-by: cbh778899 <[email protected]>

* implement new load function & bug fix

Signed-off-by: cbh778899 <[email protected]>

* implement new route & update examples

Signed-off-by: cbh778899 <[email protected]>

* fix description of embedding tag

Signed-off-by: cbh778899 <[email protected]>

* add NUM_THREADS_COUNT in env file

Signed-off-by: cbh778899 <[email protected]>

* fix bug: calculateEmbedding() was not awaited

Signed-off-by: cbh778899 <[email protected]>

* add error handling when loading dataset

Signed-off-by: cbh778899 <[email protected]>

* delete records before force load dataset & add opt param to prevent this action

Signed-off-by: cbh778899 <[email protected]>

---------

Signed-off-by: cbh778899 <[email protected]>
  • Loading branch information
cbh778899 authored Aug 16, 2024
1 parent 3027fb1 commit 5a67ed8
Show file tree
Hide file tree
Showing 11 changed files with 248 additions and 37 deletions.
1 change: 1 addition & 0 deletions .env
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ INFERENCE_ENG=llamacpp
INFERENCE_ENG_PORT=8080
INFERENCE_ENG_VERSION=server--b1-2321a5e
NUM_CPU_CORES=8.00
NUM_THREADS_COUNT=8
EMBEDDING_ENG=embedding_eng
EMBEDDING_ENG_PORT=8081
NUM_CPU_CORES_EMBEDDING=4.00
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ INFERENCE_ENG:=llamacpp
INFERENCE_ENG_PORT:=8080
INFERENCE_ENG_VERSION:=server--b1-2321a5e
NUM_CPU_CORES:=8.00
NUM_THREADS_COUNT:=8.00
NUM_THREADS_COUNT:=8

EMBEDDING_ENG:=embedding_eng
EMBEDDING_ENG_PORT:=8081
Expand Down
39 changes: 37 additions & 2 deletions actions/embedding.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// coding=utf-8

import { post } from "../tools/request.js";
import { extractAPIKeyFromRequest, validateAPIKey } from "../tools/apiKey.js";
import { getDatasetFromURL, loadDataset, parseDatasetWithoutVector } from "../database/rag-inference.js";

// Copyright [2024] [SkywardAI]
// Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -25,7 +27,7 @@ export async function calculateEmbedding(content) {
}

export async function embeddings(req, res) {
if(!req.headers.authorization) {
if(!validateAPIKey(extractAPIKeyFromRequest(req))) {
res.status(401).send("Not Authorized!");
return;
}
Expand All @@ -35,7 +37,7 @@ export async function embeddings(req, res) {
res.status(400).send("Input sentence not specified!");
}

const embedding = calculateEmbedding(input);
const embedding = await calculateEmbedding(input);

if(!embedding) {
res.status(500).send("Embedding Engine Internal Server Error")
Expand All @@ -57,4 +59,37 @@ export async function embeddings(req, res) {
total_tokens: 0
}
})
}

/**
 * Handle a dataset upload request: check the API key and the request body,
 * then load the dataset (fetched from `url`, or embedded from `json`) into the database.
 *
 * Exactly one response is sent per request:
 * * `401` - the API key extracted from the request is not valid
 * * `422` - `name` is missing, neither `json` nor `url` is given,
 *           or `json` is used but is not an array
 * * `500` - preparing, fetching, parsing or storing the dataset threw; the body is the error message
 * * `200` - otherwise (also when `loadDataset` yields no loader, so nothing had to be loaded)
 *
 * @param {Request} req Body fields: `name`, `json` or `url`, optional `force` and `keep_records`
 * @param {Response} res
 */
export async function uploadDataset(req, res) {
    if(!validateAPIKey(extractAPIKeyFromRequest(req))) {
        res.status(401).send("Not Authorized!");
        return;
    }

    // Fall back to an empty object so a request without a parsed body
    // gets the 422 below instead of throwing on destructuring.
    const { name, json, url, force, keep_records } = req.body || {};
    if(!name || (!json && !url)) {
        res.status(422).send("Please specify dataset name and one choice of json / url.");
        return;
    }

    // `json` is only consulted when no `url` is given; it is iterated entry by entry,
    // so reject anything that is not an array up front with a client error.
    if(!url && !Array.isArray(json)) {
        res.status(422).send("Dataset json must be an array of { identifier, context } objects.");
        return;
    }

    try {
        const loader = await loadDataset(name, force);
        if(loader) {
            const dataset = url ?
                await getDatasetFromURL(url) :
                await parseDatasetWithoutVector(json);

            // keep_records only has an effect together with force
            await loader(dataset, force && keep_records);
        }
    } catch(error) {
        res.status(500).send(error.message);
        // Stop here: falling through would try to send a second response.
        return;
    }

    res.status(200).send("Dataset loaded");
}
12 changes: 2 additions & 10 deletions actions/inference.js
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import { generateFingerprint } from "../tools/generator.js";
import { post } from "../tools/request.js";
import { searchByMessage } from "../database/rag-inference.js";
import { userMessageHandler } from "../tools/plugin.js";
import { extractAPIKeyFromHeader, validateAPIKey } from "../tools/apiKey.js";

/**
* Generates a response content object for chat completion.
Expand Down Expand Up @@ -91,18 +92,9 @@ async function doInference(req_body, callback, isStream) {
}
}

function validateAPIKey(api_key) {
// TODO: do something with api_key;
if(!api_key) return false;
if(+process.env.STATIC_API_KEY_ENABLED && process.env.STATIC_API_KEY) {
if(api_key !== process.env.STATIC_API_KEY) return false;
}
return true;
}

function retrieveData(req_header, req_body) {
// retrieve api key
const api_key = (req_header.authorization || '').split('Bearer ').pop();
const api_key = extractAPIKeyFromHeader(req_header);
if(!validateAPIKey(api_key)) {
return { error: true, status: 401, message: "Not Authorized" }
}
Expand Down
67 changes: 54 additions & 13 deletions database/rag-inference.js
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ import { DATASET_TABLE, SYSTEM_TABLE } from "./types.js";
* @property {Float[]} vector Embedding from embedding engine, can get from calculateEmbedding();
*/

/**
* @typedef DatasetWithoutVectorStructure
* @property {String} context The context for the AI to reference; it is added to the system instruction
* @property {String} identifier Identifier for the column, not necessarily unique
*/

/**
* Get a dataset from url
* @param {String} dataset_url
Expand All @@ -36,37 +42,72 @@ export async function getDatasetFromURL(dataset_url) {
const { rows, http_error } = await get('', {}, {URL: dataset_url});
if(http_error) return [];

return rows.map(({identifier, context, embedding})=>{
return rows.map(({row})=>{
const {identifier, context, embedding} = row;
return { identifier, context, vector: embedding }
})
}

/**
 * Build a full dataset from entries that do not yet carry an embedding.
 * Each entry's `context` is passed to `calculateEmbedding()`, one entry at a time
 * and in input order; entries for which no embedding comes back are left out.
 * @param {DatasetWithoutVectorStructure[]} dataset
 * Entries in the format of `[{ identifier: "", context: "" },...]`
 * @returns {Promise<DatasetStructure[]>} The entries with their `vector` attached, see {@link DatasetStructure}
 */
export async function parseDatasetWithoutVector(dataset) {
    const with_vectors = [];

    for(const { context, identifier } of dataset) {
        const vector = await calculateEmbedding(context);
        // no embedding returned for this entry, leave it out
        if(!vector) continue;

        with_vectors.push({ context, identifier, vector });
    }

    return with_vectors;
}

/**
* Load a given dataset into database.
* If `force` specified, it will load the dataset without check whether it is already in system.
* If `force` specified, it will delete the old loaded dataset if applicable and load the dataset.
* @param {String} dataset_name The dataset name to load
* @param {Boolean} force Specify whether to force load the dataset, default `false`.
* @returns {Promise<Promise>} A function takes a dataset array, which should in the format of `[{identifier:"",context:"",vector:[...]}]`
* @returns {Promise<Promise|null>}
* If dataset is loaded and `force` not specified, this will return null.\
* Otherwise returns function takes parameters:
* * `dataset` - The dataset to be loaded, in the format of `[{identifier:"",context:"",vector:[...]}]`
* * `keep_records` - Optional, default `false`, set to `true` to prevent remove existed data from DB if `force` is `true`
*
* @example
* const loader = await loadDataset("<your-dataset-name>");
* const dataset = await getDatasetFromURL("<your-dataset-url>");
* await loader(dataset);
* if(loader) {
* const dataset = await getDatasetFromURL("<your-dataset-url>");
* await loader(dataset);
* }
*/
export async function loadDataset(dataset_name, force = false) {
const system_table = await getTable(SYSTEM_TABLE);
const dataset_table = await getTable(DATASET_TABLE);

const dataset_loaded = !!await system_table.query()
.where(`title="loaded_dataset_name" AND value="${dataset_name}"`).toArray().length;
const dataset_loaded = !!(await system_table.query()
.where(`title="loaded_dataset_name" AND value="${dataset_name}"`).toArray()).length;

return async function(dataset) {
if(!dataset_loaded || force) {
await dataset_table.add(dataset.map(({identifier, context, vector})=>{
return { identifier, context, vector, dataset_name }
}))
if(dataset_loaded && !force) return null;

return async function(dataset, keep_records = false) {
if(dataset_loaded && force && !keep_records) {
await dataset_table.delete(`dataset_name="${dataset_name}"`);
}

await dataset_table.add(
dataset.map(({identifier, context, vector})=>{
return { identifier, context, vector, dataset_name }
}
))

if(!dataset_loaded) {
await system_table.add([{title: "loaded_dataset_name", value: dataset_name}])
}
Expand Down Expand Up @@ -97,7 +138,7 @@ export async function searchByEmbedding(dataset_name, vector, max_distance = 0.8
).search(vector).distanceType("cosine").where(`dataset_name = "${dataset_name}"`)
.limit(max_results).toArray());

if(embedding_result) {
if(embedding_result && embedding_result.length) {
if(max_results === 1) {
const { identifier, context, _distance } = embedding_result.pop();
if(_distance >= max_distance) return null;
Expand Down
11 changes: 10 additions & 1 deletion generate_production_env.html
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,8 @@
AVAILABLE_APIS += '.';
AVAILABLE_APIS += +evt.target.available_api_token.checked;
AVAILABLE_APIS += '.';
AVAILABLE_APIS += +evt.target.available_api_embedding.checked;
AVAILABLE_APIS += +evt.target.available_api_embedding_index.checked;
AVAILABLE_APIS += +evt.target.available_api_embedding_dataset.checked;
AVAILABLE_APIS += '.';
AVAILABLE_APIS += +evt.target.available_api_version.checked;
}
Expand Down Expand Up @@ -349,6 +350,14 @@ <h3>APIs Available</h3>
<div class="title">Enable Embedding API</div>
<input type="checkbox" name="available_api_embedding">
</div>
<div class="one-line api sub-api">
<div class="title">Enable Calculate Embeddings</div>
<input type="checkbox" name="available_api_embedding_index">
</div>
<div class="one-line api sub-api">
<div class="title">Enable Dataset Upload</div>
<input type="checkbox" name="available_api_embedding_dataset">
</div>
<div class="one-line api api-index">
<div class="title">Enable Version API</div>
<input type="checkbox" name="available_api_version">
Expand Down
3 changes: 2 additions & 1 deletion index.js
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ decodeEnabledAPIs();
const force_load = false;
await initDB(force_load)
if(+process.env.LOAD_DEFAULT_DATASET) {
await (await loadDataset(process.env.DEFAULT_DATASET_NAME || "production_dataset", force_load))(await loadDefaultDataset())
const loader = await loadDataset(process.env.DEFAULT_DATASET_NAME || "production_dataset", force_load)
loader && await loader(await loadDefaultDataset())
}

const app = express();
Expand Down
6 changes: 4 additions & 2 deletions routes/embedding.js
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
// limitations under the License.

import { Router } from "express";
import { embeddings } from "../actions/embedding.js";
import { embeddings, uploadDataset } from "../actions/embedding.js";
import { isRouteEnabled } from "../tools/enabledApiDecoder.js";

export default function embeddingRoute() {
const router = Router();

router.post("/", embeddings);
isRouteEnabled("embedding", "index") && router.post("/", embeddings);
isRouteEnabled("embedding", "dataset") && router.post("/dataset", uploadDataset);

return router;
}
Loading

0 comments on commit 5a67ed8

Please sign in to comment.