Skip to content

Commit

Permalink
choose 🤖
Browse files Browse the repository at this point in the history
  • Loading branch information
cboettig committed Jan 9, 2025
1 parent 60f152e commit 16c3919
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 36 deletions.
76 changes: 45 additions & 31 deletions app.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ ui <- page_sidebar(
layout_columns(
textInput("chat",
label = NULL,
"Which counties in California have the highest average social vulnerability?",
"Which four counties in California have the highest average social vulnerability?",
width = "100%"),
div(
actionButton("user_msg", "", icon = icon("paper-plane"),
Expand All @@ -44,7 +44,7 @@ ui <- page_sidebar(
col_widths = c(11, 1)),
fill = FALSE
),

textOutput("agent"),


Expand All @@ -55,14 +55,14 @@ ui <- page_sidebar(
plotOutput("chart2"),
),
col_widths = c(8, 4),
row_heights = c("600px"),
max_height = "700px"
row_heights = c("500px"),
max_height = "600px"
),

gt_output("table"),

card(fill = TRUE,
card_header(fa("robot")),
card_header(fa("robot"), textOutput("model", inline = TRUE)),
accordion(
open = FALSE,
accordion_panel(
Expand All @@ -76,13 +76,21 @@ ui <- page_sidebar(
textOutput("explanation"),
)
),
card(
),
card(
card_header("Errata"),
shiny::markdown(readr::read_file("footer.md")),
)
),

sidebar = sidebar(
selectInput(
"select",
"Select an LLM:",
list("LLama3" = "llama3",
#"OLMO2 (AllenAI)" = "olmo",
"Gorilla (UC Berkeley)" = "gorilla"
)
),

input_switch("redlines", "Redlined Areas", value = FALSE),
input_switch("svi", "Social Vulnerability", value = TRUE),
input_switch("richness", "Biodiversity Richness", value = FALSE),
Expand All @@ -99,21 +107,15 @@ ui <- page_sidebar(


repo <- "https://data.source.coop/cboettig/social-vulnerability"
pmtiles <- glue("{repo}/svi2020_us_tract.pmtiles")
parquet <- glue("{repo}/svi2020_us_tract.parquet")
pmtiles <- glue("{repo}/2022/SVI2022_US_tract.pmtiles")
parquet <- glue("{repo}/2022/SVI2022_US_tract.parquet")
con <- duckdbfs::cached_connection()
svi <- open_dataset(parquet, tblname = "svi") |> filter(RPL_THEMES > 0)
schema <- read_file("schema.yml")
system_prompt <- glue::glue(readr::read_file("system-prompt.md"),
.open = "<", .close = ">")

chat <- ellmer::chat_vllm(
base_url = "https://llm.nrp-nautilus.io/",
model = "llama3",
api_key = Sys.getenv("NRP_API_KEY"),
system_prompt = system_prompt,
api_args = list(temperature = 0)
)

safe_parse <- function(txt) {
gsub("[\r\n]", " ", txt) |> gsub("\\s+", " ", x = _)
}


# helper utilities
# faster/more scalable to pass maplibre the ids to refilter pmtiles,
Expand All @@ -140,17 +142,30 @@ server <- function(input, output, session) {
chart1 <- chart1_data |>
ggplot(aes(mean_svi)) + geom_density(fill="darkred") +
ggtitle("County-level vulnerability nation-wide")

data <- reactiveValues(df = tibble())
output$chart1 <- renderPlot(chart1)

model <- reactive(input$select)
output$model <- renderText(input$select)
observe({
schema <- read_file("schema.yml")
system_prompt <- glue::glue(readr::read_file("system-prompt.md"),
.open = "<", .close = ">")
chat <- ellmer::chat_vllm(
base_url = "https://llm.nrp-nautilus.io/",
model = model(),
api_key = Sys.getenv("NRP_API_KEY"),
system_prompt = system_prompt,
api_args = list(temperature = 0)
)

observeEvent(input$user_msg, {
stream <- chat$chat(input$chat)



# Parse response
response <- jsonlite::fromJSON(stream)
response <- jsonlite::fromJSON(safe_parse(stream))
#response <- jsonlite::fromJSON(stream)

if ("query" %in% names(response)) {
output$sql_code <- renderText(stringr::str_wrap(response$query, width = 60))
Expand Down Expand Up @@ -187,12 +202,12 @@ server <- function(input, output, session) {
}

})

})


output$map <- renderMaplibre({

m <- maplibre(center = c(-92.9, 41.3), zoom = 3, height = "400")
m <- maplibre(center = c(-104.9, 40.3), zoom = 3, height = "400")
if (input$redlines) {
m <- m |>
add_fill_layer(
Expand Down Expand Up @@ -230,7 +245,7 @@ server <- function(input, output, session) {
id = "svi_layer",
source = list(type = "vector",
url = paste0("pmtiles://", pmtiles)),
source_layer = "SVI2000_US_tract",
source_layer = "svi",
filter = filter_column(svi, data$df, "FIPS"),
fill_opacity = 0.5,
fill_color = interpolate(column = "RPL_THEMES",
Expand All @@ -239,9 +254,8 @@ server <- function(input, output, session) {
na_color = "lightgrey")
)
}
m})


m
})

}

Expand Down
8 changes: 4 additions & 4 deletions schema.yml
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
- VARIABLE_NAME: ST
DESCRIPTION: State-level FIPS code (two-digit integer)
DESCRIPTION: INTEGER State-level FIPS code (two-digit integer)
- VARIABLE_NAME: STATE
DESCRIPTION: State name
- VARIABLE_NAME: ST_ABBR
DESCRIPTION: State abbreviation
DESCRIPTION: State abbreviation, two-letter string
- VARIABLE_NAME: STCNTY
DESCRIPTION: County-level FIPS code (5 digit integer)
DESCRIPTION: INTEGER County-level FIPS code (5 digit integer)
- VARIABLE_NAME: COUNTY
DESCRIPTION: County name
- VARIABLE_NAME: FIPS
DESCRIPTION: Tract-level geographic identification (full Census Bureau FIPS code)
DESCRIPTION: INTEGER, Tract-level geographic identification (full Census Bureau FIPS code)
- VARIABLE_NAME: LOCATION
DESCRIPTION: Text description of tract county state
- VARIABLE_NAME: AREA_SQMI
Expand Down
4 changes: 3 additions & 1 deletion system-prompt.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ Include semantically meaningful columns like COUNTY and STATE name.
If your answer involves the construction of a SQL query, you must format your answer as follows:

{
"query": "your raw SQL response goes here",
"query": "your raw SQL response goes here.",
"explanation": "your explanation of the query"
}

Think carefully about your SQL query, keep it concise and ensure it is entirely valid SQL syntax.

If your answer does not involve a SQL query, please reply with the following format instead:

{
Expand Down

0 comments on commit 16c3919

Please sign in to comment.