Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix fetching models from example-models repo #919

Merged
merged 2 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 17 additions & 18 deletions hls4ml/utils/example_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,18 @@

from .config import create_config

ORGANIZATION = 'fastmachinelearning'
BRANCH = 'master'


def _load_data_config_avai(model_name):
"""
Check data and configuration availability for each model from this file:

https://github.com/hls-fpga-machine-learning/example-models/blob/master/available_data_config.json
https://github.com/fastmachinelearning/example-models/blob/master/available_data_config.json
jmitrevs marked this conversation as resolved.
Show resolved Hide resolved
"""

link_to_list = (
'https://raw.githubusercontent.com/hls-fpga-machine-learning/example-models/master/available_data_config.json'
)
link_to_list = f'https://raw.githubusercontent.com/{ORGANIZATION}/example-models/{BRANCH}/available_data_config.json'

temp_file, _ = urlretrieve(link_to_list)

Expand Down Expand Up @@ -73,12 +74,8 @@ def _load_example_data(model_name):
input_file_name = filtered_name + "_input.dat"
output_file_name = filtered_name + "_output.dat"

link_to_input = (
'https://raw.githubusercontent.com/hls-fpga-machine-learning/example-models/master/data/' + input_file_name
)
link_to_output = (
'https://raw.githubusercontent.com/hls-fpga-machine-learning/example-models/master/data/' + output_file_name
)
link_to_input = f'https://raw.githubusercontent.com/{ORGANIZATION}/example-models/{BRANCH}/data/' + input_file_name
link_to_output = f'https://raw.githubusercontent.com/{ORGANIZATION}/example-models/{BRANCH}/data/' + output_file_name

urlretrieve(link_to_input, input_file_name)
urlretrieve(link_to_output, output_file_name)
Expand All @@ -91,9 +88,7 @@ def _load_example_config(model_name):

config_name = filtered_name + "_config.yml"

link_to_config = (
'https://raw.githubusercontent.com/hls-fpga-machine-learning/example-models/master/config-files/' + config_name
)
link_to_config = f'https://raw.githubusercontent.com/{ORGANIZATION}/example-models/{BRANCH}/config-files/' + config_name

# Load the configuration as dictionary from file
urlretrieve(link_to_config, config_name)
Expand All @@ -110,7 +105,7 @@ def fetch_example_model(model_name, backend='Vivado'):
Download an example model (and example data & configuration if available) from github repo to working directory,
and return the corresponding configuration:

https://github.com/hls-fpga-machine-learning/example-models
https://github.com/fastmachinelearning/example-models

Use fetch_example_list() to see all the available models.

Expand All @@ -122,15 +117,18 @@ def fetch_example_model(model_name, backend='Vivado'):
dict: Dictionary that stores the configuration to the model
"""

# Initilize the download link and model type
download_link = 'https://raw.githubusercontent.com/hls-fpga-machine-learning/example-models/master/'
# Initialize the download link and model type
download_link = f'https://raw.githubusercontent.com/{ORGANIZATION}/example-models/{BRANCH}/'
model_type = None
model_config = None

# Check for model's type to update link
if '.json' in model_name:
model_type = 'keras'
model_config = 'KerasJson'
elif '.h5' in model_name:
model_type = 'keras'
model_config = 'KerasH5'
elif '.pt' in model_name:
model_type = 'pytorch'
model_config = 'PytorchModel'
Expand Down Expand Up @@ -158,11 +156,12 @@ def fetch_example_model(model_name, backend='Vivado'):

if _config_is_available(model_name):
config = _load_example_config(model_name)
config[model_config] = model_name # Ensure that paths are correct
else:
config = _create_default_config(model_name, model_config, backend)

# If the model is a keras model then have to download its weight file as well
if model_type == 'keras':
if model_type == 'keras' and '.json' in model_name:
model_weight_name = model_name[:-5] + "_weights.h5"

download_link_weight = download_link + model_type + '/' + model_weight_name
Expand All @@ -174,7 +173,7 @@ def fetch_example_model(model_name, backend='Vivado'):


def fetch_example_list():
link_to_list = 'https://raw.githubusercontent.com/hls-fpga-machine-learning/example-models/master/available_models.json'
link_to_list = f'https://raw.githubusercontent.com/{ORGANIZATION}/example-models/{BRANCH}/available_models.json'

temp_file, _ = urlretrieve(link_to_list)

Expand Down
32 changes: 32 additions & 0 deletions test/pytest/test_fetch_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import ast
import io
from contextlib import redirect_stdout
from pathlib import Path

import pytest

import hls4ml

test_root_path = Path(__file__).parent


@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus'])
def test_fetch_example_utils(backend):
f = io.StringIO()
with redirect_stdout(f):
hls4ml.utils.fetch_example_list()
out = f.getvalue()

model_list = ast.literal_eval(out) # Check if we indeed got a dictionary back

assert 'qkeras_mnist_cnn.json' in model_list['keras']

# This model has an example config that is also downloaded. Stored configurations don't set "Backend" value.
config = hls4ml.utils.fetch_example_model('qkeras_mnist_cnn.json', backend=backend)
config['KerasJson'] = 'qkeras_mnist_cnn.json'
config['KerasH5']
config['Backend'] = backend
config['OutputDir'] = str(test_root_path / f'hls4mlprj_fetch_example_{backend}')

hls_model = hls4ml.converters.keras_to_hls(config)
hls_model.compile() # For now, it is enough if it compiles, we're only testing downloading works as expected
Loading