-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(magic): add run_personal and run_shared magic commands (#50)
* feat(magic): add run_personal and run_shared magic commands * feat(magic): add support for jinja2 templates * refactor(magic): create singlestoredb.magics module * feat(magic): use tempfile.TemporaryDirectory for downloaded file
- Loading branch information
1 parent
1b6131a
commit 809f513
Showing
3 changed files
with
143 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from IPython.core.interactiveshell import InteractiveShell | ||
|
||
from .run_personal import RunPersonalMagic | ||
from .run_shared import RunSharedMagic | ||
|
||
# In order to actually use these magics, we must register them with a | ||
# running IPython. | ||
|
||
|
||
def load_ipython_extension(ip: InteractiveShell) -> None: | ||
""" | ||
Any module file that define a function named `load_ipython_extension` | ||
can be loaded via `%load_ext module.path` or be configured to be | ||
autoloaded by IPython at startup time. | ||
""" | ||
|
||
# Load jupysql extension | ||
# This is necessary for jupysql to initialize internal state | ||
# required to render messages | ||
assert ip.extension_manager is not None | ||
result = ip.extension_manager.load_extension('sql') | ||
if result == 'no load function': | ||
raise RuntimeError('Could not load sql extension. Is jupysql installed?') | ||
|
||
# Check if %run magic command is defined | ||
if ip.find_line_magic('run') is None: | ||
raise RuntimeError( | ||
'%run magic command is not defined. ' | ||
'Is it available in your IPython environment?', | ||
) | ||
|
||
# Register run_personal and run_shared | ||
ip.register_magics(RunPersonalMagic(ip)) | ||
ip.register_magics(RunSharedMagic(ip)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import os | ||
import tempfile | ||
from typing import Any | ||
|
||
from IPython.core.interactiveshell import InteractiveShell | ||
from IPython.core.magic import line_magic | ||
from IPython.core.magic import Magics | ||
from IPython.core.magic import magics_class | ||
from IPython.core.magic import needs_local_scope | ||
from IPython.core.magic import no_var_expand | ||
from jinja2 import Template | ||
|
||
|
||
@magics_class | ||
class RunPersonalMagic(Magics): | ||
def __init__(self, shell: InteractiveShell): | ||
Magics.__init__(self, shell=shell) | ||
|
||
@no_var_expand | ||
@needs_local_scope | ||
@line_magic('run_personal') | ||
def run_personal(self, line: str, local_ns: Any = None) -> Any: | ||
""" | ||
Downloads a personal file using the %sql magic and then runs it using %run. | ||
Examples:: | ||
# Line usage | ||
%run_personal personal_file.ipynb | ||
%run_personal {{ sample_notebook_name }} | ||
""" | ||
|
||
template = Template(line.strip()) | ||
personal_file = template.render(local_ns) | ||
if not personal_file: | ||
raise ValueError('No personal file specified.') | ||
if (personal_file.startswith("'") and personal_file.endswith("'")) or \ | ||
(personal_file.startswith('"') and personal_file.endswith('"')): | ||
personal_file = personal_file[1:-1] | ||
if not personal_file: | ||
raise ValueError('No personal file specified.') | ||
|
||
with tempfile.TemporaryDirectory() as temp_dir: | ||
temp_file_path = os.path.join(temp_dir, personal_file) | ||
sql_command = ( | ||
f"DOWNLOAD PERSONAL FILE '{personal_file}' " | ||
f"TO '{temp_file_path}'" | ||
) | ||
|
||
# Execute the SQL command | ||
self.shell.run_line_magic('sql', sql_command) | ||
# Run the downloaded file | ||
self.shell.run_line_magic('run', f'"{temp_file_path}"') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import os | ||
import tempfile | ||
from typing import Any | ||
|
||
from IPython.core.interactiveshell import InteractiveShell | ||
from IPython.core.magic import line_magic | ||
from IPython.core.magic import Magics | ||
from IPython.core.magic import magics_class | ||
from IPython.core.magic import needs_local_scope | ||
from IPython.core.magic import no_var_expand | ||
from jinja2 import Template | ||
|
||
|
||
@magics_class | ||
class RunSharedMagic(Magics): | ||
def __init__(self, shell: InteractiveShell): | ||
Magics.__init__(self, shell=shell) | ||
|
||
@no_var_expand | ||
@needs_local_scope | ||
@line_magic('run_shared') | ||
def run_shared(self, line: str, local_ns: Any = None) -> Any: | ||
""" | ||
Downloads a shared file using the %sql magic and then runs it using %run. | ||
Examples:: | ||
# Line usage | ||
%run_shared shared_file.ipynb | ||
%run_shared {{ sample_notebook_name }} | ||
""" | ||
|
||
template = Template(line.strip()) | ||
shared_file = template.render(local_ns) | ||
if not shared_file: | ||
raise ValueError('No shared file specified.') | ||
if (shared_file.startswith("'") and shared_file.endswith("'")) or \ | ||
(shared_file.startswith('"') and shared_file.endswith('"')): | ||
shared_file = shared_file[1:-1] | ||
if not shared_file: | ||
raise ValueError('No personal file specified.') | ||
|
||
with tempfile.TemporaryDirectory() as temp_dir: | ||
temp_file_path = os.path.join(temp_dir, shared_file) | ||
sql_command = f"DOWNLOAD SHARED FILE '{shared_file}' TO '{temp_file_path}'" | ||
|
||
# Execute the SQL command | ||
self.shell.run_line_magic('sql', sql_command) | ||
# Run the downloaded file | ||
self.shell.run_line_magic('run', f'"{temp_file_path}"') |