Skip to content

Commit

Permalink
Merge pull request #4 from seantis/feature/ai-assistant
Browse files Browse the repository at this point in the history
Feature/ai assistant
  • Loading branch information
somehowchris authored Jun 11, 2024
2 parents 82d689d + 6ea2039 commit 4f772c8
Show file tree
Hide file tree
Showing 22 changed files with 803 additions and 285 deletions.
3 changes: 3 additions & 0 deletions development.ini.example
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ pyramid.available_languages =

sqlalchemy.url = sqlite:///%(here)s/riskmatrix.sqlite

openai_api_key=
anthropic_api_key=

session.type = file
session.data_dir = %(here)s/data/sessions/data
session.lock_dir = %(here)s/data/sessions/lock
Expand Down
22 changes: 22 additions & 0 deletions src/riskmatrix/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from riskmatrix.route_factories import root_factory
from riskmatrix.security import authenticated_user
from riskmatrix.security_policy import SessionSecurityPolicy
from openai import OpenAI
from anthropic import Anthropic


from typing import TYPE_CHECKING
Expand Down Expand Up @@ -71,5 +73,25 @@ def main(
with Configurator(settings=settings, root_factory=root_factory) as config:
includeme(config)

if openai_apikey := settings.get('openai_api_key'):

openai_client = OpenAI(
api_key=openai_apikey
)
config.add_request_method(
lambda r: openai_client,
'openai',
reify=True
)
if anthropic_apikey := settings.get('anthropic_api_key'):
anthropic_client = Anthropic(
api_key=anthropic_apikey
)
config.add_request_method(
lambda r: anthropic_client,
'anthropic',
reify=True
)

app = config.make_wsgi_app()
return Fanstatic(app, versioning=True)
1 change: 0 additions & 1 deletion src/riskmatrix/data_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ def cell(self, data: Any) -> str:
params = {}
if 'class_name' in self.options:
params['class'] = self.options['class_name']

if callable(self.sort_key):
params['data_order'] = self.sort_key(data)
return f'<td {html_params(**params)}>{self.format_data(data)}</td>'
Expand Down
3 changes: 3 additions & 0 deletions src/riskmatrix/layouts/layout.pt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
<script type="text/javascript" src="${layout.static_url('riskmatrix:static/js/bundle.min.js')}"></script>
<script type="text/javascript" src="${layout.static_url('riskmatrix:static/js/sentry.js')}"></script>
</tal:block>
<script type="text/javascript" src="${layout.static_url('riskmatrix:static/js/plotly.min.js')}"></script>
<script type="text/javascript" src="${layout.static_url('riskmatrix:static/js/marked.min.js')}"></script>

<title>RiskMatrix<tal:b tal:condition="exists:title"> — ${title}</tal:b></title>

</head>
Expand Down
3 changes: 2 additions & 1 deletion src/riskmatrix/layouts/navbar.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ def navbar(context: object, request: 'IRequest') -> 'RenderData':
NavbarEntry(
request,
_('Risk Catalog'),
request.route_url('risk_catalog')
request.route_url('risk_catalog'),
lambda request, url: request.path_url.startswith(request.route_url('risk_catalog'))
),
NavbarEntry(
request,
Expand Down
3 changes: 2 additions & 1 deletion src/riskmatrix/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .asset import Asset
from .organization import Organization
from .risk import Risk
from .risk_assessment import RiskAssessment
from .risk_assessment import RiskAssessment, RiskMatrixAssessment
from .risk_catalog import RiskCatalog
from .risk_category import RiskCategory
from .user import User
Expand Down Expand Up @@ -60,5 +60,6 @@ def includeme(config: 'Configurator') -> None:
'RiskAssessment',
'RiskCatalog',
'RiskCategory',
'RiskMatrixAssessment',
'User'
)
4 changes: 4 additions & 0 deletions src/riskmatrix/models/organization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from sqlalchemy.orm import relationship
from sqlalchemy.orm import Mapped
from uuid import uuid4
from riskmatrix.models.risk_catalog import RiskCatalog

from riskmatrix.orm.meta import Base
from riskmatrix.orm.meta import str_256
Expand Down Expand Up @@ -32,6 +33,9 @@ class Organization(Base):
risks: Mapped[list['Risk']] = relationship(
back_populates='organization',
)
risk_catalogs: Mapped[list['RiskCatalog']] = relationship(
back_populates='organization',
)
users: Mapped[list['User']] = relationship(
back_populates='organization',
)
Expand Down
6 changes: 5 additions & 1 deletion src/riskmatrix/models/risk_assessment.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from riskmatrix.orm.meta import UUIDStrPK


from typing import Any
from typing import Any, ClassVar
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from riskmatrix.types import ACL
Expand Down Expand Up @@ -146,3 +146,7 @@ def __acl__(self) -> list['ACL']:
return [
(Allow, f'org_{self.risk.organization_id}', ['view']),
]


class RiskMatrixAssessment(RiskAssessment):
nr: ClassVar[int]
2 changes: 1 addition & 1 deletion src/riskmatrix/models/risk_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class RiskCatalog(Base):
modified: Mapped[datetime | None] = mapped_column(onupdate=utcnow)

risks: Mapped[list['Risk']] = relationship(back_populates='catalog')
organization: Mapped['Organization'] = relationship()
organization: Mapped['Organization'] = relationship(back_populates='risk_catalogs')

def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion src/riskmatrix/static/css/bootstrap.min.css

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/riskmatrix/static/css/bootstrap.min.css.map

Large diffs are not rendered by default.

44 changes: 39 additions & 5 deletions src/riskmatrix/static/css/custom.css
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ iframe.width-changed {
}

.is-invalid ~ .editor-toolbar {
border-color: #dc3545;
border-color: #DD1122;
padding-right: calc(1.5em + .75rem);
background-image: url("data:image/svg+xml,%3csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 12 12' width='12' height='12' fill='none' stroke='%23dc3545'%3e%3ccircle cx='6' cy='6' r='4.5'/%3e%3cpath stroke-linejoin='round' d='M5.8 3.6h.4L6 6.5z'/%3e%3ccircle cx='6' cy='8.2' r='.6' fill='%23dc3545' stroke='none'/%3e%3c/svg%3e");
background-repeat: no-repeat;
Expand All @@ -235,7 +235,7 @@ iframe.width-changed {
}

.is-invalid ~ .CodeMirror {
border-color: #dc3545;
border-color: #DD1122;
}

.is-invalid ~ .editor-toolbar + .CodeMirror {
Expand All @@ -259,15 +259,15 @@ h1 {margin-bottom: .6em}
.matrix th { width: 5%; text-align: center;}
.matrix td { width: 15%; }

.matrix .low {
.matrix.low {
background-color: #98FB98; /* Pale Green */
}

.matrix .medium {
.matrix.medium {
background-color: #FFFFA1; /* Light Yellow */
}

.matrix .high {
.matrix.high {
background-color: #FFA07A; /* Light Red */
}

Expand All @@ -282,4 +282,38 @@ h1 {margin-bottom: .6em}
-webkit-column-count: 1; /* For Safari and Chrome */
-moz-column-count: 1; /* For Firefox */
column-count: 1; /* Standard syntax */
}

.plotly-graph-div {
margin-left: auto;
margin-right: auto;
}

.suggestion-container {
display: flex;
flex-wrap: wrap;
gap: 12px;
}

.suggestion-box {
border: 1px solid black;
padding: 10px;
width: calc(50% - 12px); /* Adjust width to account for the gap */
border-radius: 6px;
}

.suggestion-box h2 {
font-weight: bold;
font-size: 1.2em;
}


#generate-risks-xhr ul {
padding-left: 12px;
list-style-type: none!important;

}

#generate-risks-xhr li::marker {
list-style-type: none!important;
}
6 changes: 6 additions & 0 deletions src/riskmatrix/static/js/marked.min.js

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions src/riskmatrix/static/js/plotly.min.js

Large diffs are not rendered by default.

145 changes: 145 additions & 0 deletions src/riskmatrix/static/js/xhr_edit.js
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,148 @@ $(function() {
return false;
});
});

$(document).ready(function () {
marked.use({
hooks: {
postprocess: function (html) {
return html.replace(/\sdisabled=""/g, '');
}
}
});

var risks_modal = $("form#generate-risks-xhr-form");

if (risks_modal.length === 0) return;
risks_modal = risks_modal[0];
$('div.modal#generate-risks-xhr').modal('show');

$('div.modal#generate-risks-xhr').on('hidden.bs.modal', function (e) {
window.location.href = '/risk_catalog';
})

var answers = JSON.parse(risks_modal.dataset.answers);
var catalogs = JSON.parse(risks_modal.dataset.catalogs);
var idx = 0;
var csrf_token = $("#generate-risks")[0].dataset['csrfToken'];
var title = $("h5#generate-risks-xhr-title").first();
var save_button = $("button#generate-risks")[0];
title.text(catalogs[idx].title);

function initiateGeneration(catalog) {
console.log(catalog)
console.log("Fetching data...");
title.text("Generating risks for '" + catalog.name + "' catalog..");
// Initialize a variable to accumulate received text
let accumulatedText = '';

// Function to update the modal body with new text
function updateModalBody(newText) {
// Remove all contents from div.modal-body
$("#generate-risks-xhr div.modal-body").empty();
const htmlContent = marked.parse(newText, { gfm: true });
// Insert the parsed HTML into div.modal-body
$("#generate-risks-xhr div.modal-body").html(htmlContent);
}

// Using the Fetch API to handle the streaming response
fetch('/risk_catalog/generate/stream', {
method: 'POST',
headers: {
'X-CSRF-Token': csrf_token,
'X-Requested-With': 'XMLHttpRequest' // Mark the request as an AJAX request
},
body: JSON.stringify({
answers,
catalog
}),
}).then(response => {
updateModalBody('Awaiting magician response...')
save_button.disabled = true;
const reader = response.body.getReader();

// Function to process the stream
(async function readStream() {
while (true) {
const { done, value } = await reader.read();
if (done) break;
let textChunk = new TextDecoder("utf-8").decode(value);
accumulatedText += textChunk; // Accumulate the new text chunk
updateModalBody(accumulatedText); // Update the modal body with the new accumulated text
}
save_button.disabled = false;
title.text("Generated Risks for '" + catalog.name + "' catalog");
})();
}).catch(error => {
console.error("Error fetching data:", error);
});



}
$("button#generate-risks").on('click', function (event) {
event.preventDefault();
// Array to hold the objects
var risks = [];

$(this).disabled = true;

// Iterate over each list item
$('#generate-risks-xhr div.modal-body ul > li').each(function () {
// For each 'li', find the 'input' (checkbox) and check its checked status
var isChecked = $(this).find('input[type="checkbox"]').is(':checked');

// Extract the risk name from the 'strong' element
var name = $(this).find('strong').text();

// Extract the description by getting the entire text of 'li'
// and then removing the name (including the following colon and space).
var description = $(this).text().replace(name + ': ', '');
if (isChecked) {
// Construct the risk object and add it to the 'risks' array
risks.push({
name: name,
description: description,
catalog: catalogs[idx],
});
}
});


Promise.all(risks.map(risk => {


// make post request to /risk_catalog/{id}/add
return fetch('/risks_catalog/' + catalogs[idx].id + '/add', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'X-CSRF-Token': csrf_token,
'X-Requested-With': 'XMLHttpRequest' // Mark the request as an AJAX request
},
body: JSON.stringify(risk),
}).then(response => {
return response.json();
}).then(data => {
risks = [];
}).catch(error => {
console.error("Error fetching data:", error);
}).finally(() => {
});
})).then(() => {
idx += 1;
if (idx >= catalogs.length) {
console.log("No more catalogs to process");
$('div.modal#generate-risks-xhr').modal('hide');
return;
}
initiateGeneration(catalogs[idx]);
})



});


initiateGeneration(catalogs[idx]);
});
9 changes: 7 additions & 2 deletions src/riskmatrix/subscribers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pyramid.events import NewRequest
from pyramid.events import NewResponse

import secrets

from typing import TYPE_CHECKING
if TYPE_CHECKING:
Expand All @@ -19,7 +19,7 @@ def default_csp_directives(request: 'IRequest') -> dict[str, str]:
"frame-ancestors": "'none'",
"img-src": "'self' data: blob:",
"object-src": "'self'",
"script-src": "'self' blob: resource:",
"script-src": f"'self' 'nonce-{request.csp_nonce}' blob: resource:",
"style-src": "'self' 'unsafe-inline'",
}

Expand Down Expand Up @@ -50,7 +50,12 @@ def sentry_context(event: NewRequest) -> None:
with configure_scope() as scope:
scope.user = {'id': request.user.id}

def request_none_generator(event: 'NewRequest') -> None:
request = event.request
request.set_property(lambda r: secrets.token_urlsafe(), 'csp_nonce', reify=True)


def includeme(config: 'Configurator') -> None:
config.add_subscriber(csp_header, NewResponse)
config.add_subscriber(request_none_generator, NewRequest)
config.add_subscriber(sentry_context, NewRequest)
Loading

0 comments on commit 4f772c8

Please sign in to comment.