-
Notifications
You must be signed in to change notification settings - Fork 306
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Auto Cache Plugin #2971
base: master
Are you sure you want to change the base?
Auto Cache Plugin #2971
Conversation
Signed-off-by: Daniel Sola <[email protected]>
Signed-off-by: Daniel Sola <[email protected]>
Signed-off-by: Daniel Sola <[email protected]>
Signed-off-by: Daniel Sola <[email protected]>
Signed-off-by: Daniel Sola <[email protected]>
Signed-off-by: Daniel Sola <[email protected]>
Signed-off-by: Daniel Sola <[email protected]>
Signed-off-by: Daniel Sola <[email protected]>
Signed-off-by: Daniel Sola <[email protected]>
Signed-off-by: Daniel Sola <[email protected]>
Signed-off-by: Daniel Sola <[email protected]>
Signed-off-by: Daniel Sola <[email protected]>
flytekit/core/auto_cache.py
Outdated
self.cache_serialize = cache_serialize | ||
self.cache_version = cache_version | ||
self.cache_ignore_input_vars = cache_ignore_input_vars |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what's the purpose of saving this state here? aren't these just forwarded to the underlying TaskMetadata
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The idea with this is the user could use the CachePolicy to define all the arguments relating to caching. This simplifies the UX a bit as opposed to having a CachePolicy and a cache_ignore_input_vars, cache_serialize, etc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is a little confusing:
cache_version
should not be exposed, since theAutoCache
protocol is meant to produce this value automatically, andsalt
is meant to fulfill the need of manually bumping the cache.- I think it makes sense to keep
cache_serialize
andcache_ignore_input_vars
as options to specify in the@task
decorator as opposed to introducing this redundancy here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good. Sounds like there is a separate effort aimed at collecting all of the caching arguments here: flyteorg/flyte#6143
Happy to use that instead and simplify the arguments here!
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #2971 +/- ##
==========================================
+ Coverage 76.49% 77.96% +1.46%
==========================================
Files 200 202 +2
Lines 20901 21324 +423
Branches 2689 2739 +50
==========================================
+ Hits 15989 16625 +636
+ Misses 4195 3904 -291
- Partials 717 795 +78 ☔ View full report in Codecov by Sentry. |
Code Review Agent Run #bc105bActionable Suggestions - 9
Additional Suggestions - 7
Review Details
|
Changelist by BitoThis pull request implements the following key changes.
|
flytekit/core/auto_cache.py
Outdated
hash_obj = hashlib.sha256(task_hash.encode()) | ||
return hash_obj.hexdigest() | ||
|
||
return None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider returning an empty string instead of None
for consistency in return types. The method signature indicates it returns str
but can return None
.
Code suggestion
Check the AI-generated fix before applying
return None | |
return "" |
Code Review Run #bc105b
Is this a valid issue, or was it incorrectly flagged by the Agent?
- it was incorrectly flagged
numpy==1.24.3 | ||
pandas==2.0.3 | ||
requests==2.31.0 | ||
matplotlib==3.7.2 | ||
pillow==10.0.0 | ||
scipy==1.11.2 | ||
pytest==7.4.0 | ||
urllib3==2.0.4 | ||
cryptography==41.0.3 | ||
setuptools==68.0.0 | ||
flask==2.3.2 | ||
django==4.2.4 | ||
scikit-learn==1.3.0 | ||
beautifulsoup4==4.12.2 | ||
pyyaml==6.0 | ||
fastapi==0.100.0 | ||
sqlalchemy==2.0.36 | ||
tqdm==4.65.0 | ||
pytest-mock==3.11.0 | ||
jinja2==3.1.2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider pinning dependencies to compatible versions using ~=
or >=
instead of ==
to allow for minor version updates that include security patches while maintaining compatibility. This helps keep dependencies up-to-date with security fixes.
Code suggestion
Check the AI-generated fix before applying
numpy==1.24.3 | |
pandas==2.0.3 | |
requests==2.31.0 | |
matplotlib==3.7.2 | |
pillow==10.0.0 | |
scipy==1.11.2 | |
pytest==7.4.0 | |
urllib3==2.0.4 | |
cryptography==41.0.3 | |
setuptools==68.0.0 | |
flask==2.3.2 | |
django==4.2.4 | |
scikit-learn==1.3.0 | |
beautifulsoup4==4.12.2 | |
pyyaml==6.0 | |
fastapi==0.100.0 | |
sqlalchemy==2.0.36 | |
tqdm==4.65.0 | |
pytest-mock==3.11.0 | |
jinja2==3.1.2 | |
numpy~=1.24.3 | |
pandas~=2.0.3 | |
requests~=2.31.0 | |
matplotlib~=3.7.2 | |
pillow~=10.0.0 | |
scipy~=1.11.2 | |
pytest~=7.4.0 | |
urllib3~=2.0.4 | |
cryptography~=41.0.3 | |
setuptools~=68.0.0 | |
flask~=2.3.2 | |
django~=4.2.4 | |
scikit-learn~=1.3.0 | |
beautifulsoup4~=4.12.2 | |
pyyaml~=6.0 | |
fastapi~=0.100.0 | |
sqlalchemy~=2.0.36 | |
tqdm~=4.65.0 | |
pytest-mock~=3.11.0 | |
jinja2~=3.1.2 |
Code Review Run #bc105b
Is this a valid issue, or was it incorrectly flagged by the Agent?
- it was incorrectly flagged
packages = cache.get_version_dict().keys() | ||
|
||
expected_packages = {'PIL', 'bs4', 'numpy', 'pandas', 'scipy', 'sklearn'} | ||
set(packages) == expected_packages, f"Expected keys {expected_packages}, but got {set(packages)}" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The assertion statement appears to be missing the assert
keyword, which means this comparison won't actually validate the test condition. Consider adding the assert
keyword.
Code suggestion
Check the AI-generated fix before applying
set(packages) == expected_packages, f"Expected keys {expected_packages}, but got {set(packages)}" | |
assert set(packages) == expected_packages, f"Expected keys {expected_packages}, but got {set(packages)}" |
Code Review Run #bc105b
Is this a valid issue, or was it incorrectly flagged by the Agent?
- it was incorrectly flagged
|
||
Returns: | ||
Set[Callable[..., Any]]: A set of all dependencies found. | ||
""" | ||
|
||
dependencies = set() | ||
source = textwrap.dedent(inspect.getsource(func)) | ||
parsed_ast = ast.parse(source) | ||
|
||
# Initialize a dictionary to mimic the function's global namespace for locally defined imports | ||
locals_dict = {} | ||
# Initialize a dictionary to hold constant imports and class attributes | ||
constant_imports = {} | ||
# If class attributes are provided, include them in the constant imports | ||
if class_attributes: | ||
constant_imports.update(class_attributes) | ||
|
||
# Check each function call in the AST | ||
for node in ast.walk(parsed_ast): | ||
if isinstance(node, ast.Import): | ||
# For each alias in the import statement, we import the module and add it to the locals_dict. | ||
# This is because the module itself is being imported, not a specific attribute or function. | ||
for alias in node.names: | ||
module = importlib.import_module(alias.name) | ||
locals_dict[self._get_alias_name(alias)] = module | ||
# We then get all the literal constants defined in the module's __init__.py file. | ||
# These constants are later checked for usage within the function. | ||
module_constants = self.get_module_literal_constants(module) | ||
constant_imports.update( | ||
{f"{self._get_alias_name(alias)}.{name}": value for name, value in module_constants.items()} | ||
) | ||
elif isinstance(node, ast.ImportFrom): | ||
module_name = node.module | ||
module = importlib.import_module(module_name) | ||
for alias in node.names: | ||
# Attempt to resolve the imported object directly from the module | ||
imported_obj = getattr(module, alias.name, None) | ||
if imported_obj: | ||
# If the object is found directly in the module, add it to the locals_dict | ||
locals_dict[self._get_alias_name(alias)] = imported_obj | ||
# Check if the imported object is a literal constant and add it to constant_imports if so | ||
if self.is_literal_constant(imported_obj): | ||
constant_imports.update({f"{self._get_alias_name(alias)}": imported_obj}) | ||
else: | ||
# If the object is not found directly in the module, attempt to import it as a submodule | ||
# This is necessary for cases like `from PIL import Image`, where Image is not imported in PIL's __init__.py | ||
# PIL and similar packages use different mechanisms to expose their objects, requiring this fallback approach | ||
submodule = importlib.import_module(f"{module_name}.{alias.name}") | ||
imported_obj = getattr(submodule, alias.name, None) | ||
locals_dict[self._get_alias_name(alias)] = imported_obj | ||
|
||
elif isinstance(node, ast.Call): | ||
# Add callable to the set of dependencies if it's user defined and continue the recursive search within those callables. | ||
func_name = self._get_callable_name(node.func) | ||
if func_name and func_name not in visited: | ||
visited.add(func_name) | ||
try: | ||
# Attempt to resolve the callable object using locals first, then globals | ||
func_obj = self._resolve_callable(func_name, locals_dict) or self._resolve_callable( | ||
func_name, func.__globals__ | ||
) | ||
# If the callable is a class and user-defined, we add and search all method. We also include attributes as potential constants. | ||
if inspect.isclass(func_obj) and self._is_user_defined(func_obj): | ||
current_class_attributes = { | ||
f"class.{func_name}.{name}": value for name, value in func_obj.__dict__.items() | ||
} | ||
for name, method in inspect.getmembers(func_obj, predicate=inspect.isfunction): | ||
if method not in visited: | ||
visited.add(method.__qualname__) | ||
dependencies.add(method) | ||
dependencies.update( | ||
self._get_function_dependencies(method, visited, current_class_attributes) | ||
) | ||
# If the callable is a function or method and user-defined, add it as a dependency and search its dependencies | ||
elif (inspect.isfunction(func_obj) or inspect.ismethod(func_obj)) and self._is_user_defined( | ||
func_obj | ||
): | ||
# Add the function or method as a dependency | ||
dependencies.add(func_obj) | ||
# Recursively search the function or method's dependencies | ||
dependencies.update(self._get_function_dependencies(func_obj, visited)) | ||
except (NameError, AttributeError) as e: | ||
click.secho(f"Could not process the callable {func_name} due to error: {str(e)}", fg="yellow") | ||
|
||
# Extract potential constants from the global import context | ||
global_constants = {} | ||
for key, value in func.__globals__.items(): | ||
if hasattr(value, "__dict__"): | ||
module_constants = self.get_module_literal_constants(value) | ||
global_constants.update({f"{key}.{name}": value for name, value in module_constants.items()}) | ||
elif self.is_literal_constant(value): | ||
global_constants[key] = value | ||
|
||
# Check for the usage of all potnential constants and update the set of constants to be hashed | ||
referenced_constants = self.get_referenced_constants( | ||
func=func, constant_imports=constant_imports, global_constants=global_constants | ||
) | ||
self.constants.update(referenced_constants) | ||
|
||
return dependencies |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The _get_function_dependencies
method is quite long (>100 lines) and handles multiple responsibilities including import handling, AST traversal, and constant extraction. Consider breaking it down into smaller focused methods for better maintainability.
Code suggestion
Check the AI-generated fix before applying
@@ -108,114 +108,50 @@
def _get_function_dependencies(
self, func: Callable[..., Any], visited: Set[str], class_attributes: dict = None
) -> Set[Callable[..., Any]]:
- dependencies = set()
- source = textwrap.dedent(inspect.getsource(func))
- parsed_ast = ast.parse(source)
-
- # Initialize dictionaries
- locals_dict = {}
- constant_imports = {}
- if class_attributes:
- constant_imports.update(class_attributes)
-
- # Check each function call in the AST
- for node in ast.walk(parsed_ast):
- if isinstance(node, ast.Import):
- # Handle imports
- for alias in node.names:
- module = importlib.import_module(alias.name)
- locals_dict[self._get_alias_name(alias)] = module
- module_constants = self.get_module_literal_constants(module)
- constant_imports.update(
- {f"{self._get_alias_name(alias)}.{name}": value for name, value in module_constants.items()}
- )
- elif isinstance(node, ast.ImportFrom):
- # Handle from imports
- module_name = node.module
- module = importlib.import_module(module_name)
- for alias in node.names:
- imported_obj = getattr(module, alias.name, None)
- if imported_obj:
- locals_dict[self._get_alias_name(alias)] = imported_obj
- if self.is_literal_constant(imported_obj):
- constant_imports.update({f"{self._get_alias_name(alias)}": imported_obj})
- else:
- submodule = importlib.import_module(f"{module_name}.{alias.name}")
- imported_obj = getattr(submodule, alias.name, None)
- locals_dict[self._get_alias_name(alias)] = imported_obj
+ dependencies = set()
+ source = textwrap.dedent(inspect.getsource(func))
+ parsed_ast = ast.parse(source)
+
+ locals_dict, constant_imports = self._initialize_dictionaries(class_attributes)
+
+ for node in ast.walk(parsed_ast):
+ if isinstance(node, ast.Import):
+ self._handle_imports(node, locals_dict, constant_imports)
+ elif isinstance(node, ast.ImportFrom):
+ self._handle_import_from(node, locals_dict, constant_imports)
+ elif isinstance(node, ast.Call):
+ self._process_callable(node, locals_dict, func, visited, dependencies)
+
+ self._extract_constants(func, constant_imports)
+ return dependencies
Code Review Run #bc105b
Is this a valid issue, or was it incorrectly flagged by the Agent?
- it was incorrectly flagged
@@ -132,9 +133,9 @@ | |||
|
|||
@overload | |||
def task( | |||
_task_function: Callable[P, FuncOut], | |||
_task_function: Callable[..., FuncOut], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider if changing _task_function
type hint from Callable[P, FuncOut]
to Callable[..., FuncOut]
could make the type checking less strict. The ...
allows any arguments which may hide potential type errors at compile time. Similar issues were also found in:
- flytekit/core/workflow.py (line 846-864)
- flytekit/core/workflow.py (line 857-901)
Code suggestion
Check the AI-generated fix before applying
_task_function: Callable[..., FuncOut], | |
_task_function: Callable[P, FuncOut], |
Code Review Run #bc105b
Is this a valid issue, or was it incorrectly flagged by the Agent?
- it was incorrectly flagged
Given a function, generates a version hash based on its source code and the salt. | ||
""" | ||
|
||
def __init__(self, salt: str = "") -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider adding validation for empty salt
parameter in __init__
. An empty salt could potentially lead to weaker caching behavior.
Code suggestion
Check the AI-generated fix before applying
@@ -23,8 +23,10 @@
def __init__(self, salt: str = "") -> None:
"""
Initialize the CacheFunctionBody instance with a salt value.
"""
- self.salt = salt
+ if not salt:
+ raise ValueError("Salt cannot be empty as it affects cache effectiveness")
+ self.salt = salt
Code Review Run #bc105b
Is this a valid issue, or was it incorrectly flagged by the Agent?
- it was incorrectly flagged
def get_version(self, params: VersionParameters) -> str: | ||
if params.func is None: | ||
raise ValueError("Function-based cache requires a function parameter") | ||
return self._get_version(func=params.func) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The get_version
method could benefit from type checking params.func
before accessing it to provide a more descriptive error message.
Code suggestion
Check the AI-generated fix before applying
def get_version(self, params: VersionParameters) -> str: | |
if params.func is None: | |
raise ValueError("Function-based cache requires a function parameter") | |
return self._get_version(func=params.func) | |
def get_version(self, params: VersionParameters) -> str: | |
if params.func is None: | |
raise ValueError("Function-based cache requires a function parameter") | |
if not callable(params.func): | |
raise TypeError("params.func must be a callable function") | |
return self._get_version(func=params.func) |
Code Review Run #bc105b
Is this a valid issue, or was it incorrectly flagged by the Agent?
- it was incorrectly flagged
module_b.another_helper() | ||
result = norm([1, 2, 3]) | ||
print(result) | ||
sum([SOME_CONSTANT, utils.THIRD_CONSTANT]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The sum()
function call result is not being stored or used, which may indicate meaningless executed code. Consider either storing the result or removing if not needed.
Code suggestion
Check the AI-generated fix before applying
sum([SOME_CONSTANT, utils.THIRD_CONSTANT]) | |
result = sum([SOME_CONSTANT, utils.THIRD_CONSTANT]) |
Code Review Run #bc105b
Is this a valid issue, or was it incorrectly flagged by the Agent?
- it was incorrectly flagged
except Exception as e: | ||
click.secho(f"Could not get version for {package_name} using importlib.metadata: {str(e)}", fg="yellow") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Catching a broad 'Exception' may hide bugs. Consider catching specific exceptions instead.
Code suggestion
Check the AI-generated fix before applying
except Exception as e: | |
click.secho(f"Could not get version for {package_name} using importlib.metadata: {str(e)}", fg="yellow") | |
except (ImportError, AttributeError) as e: | |
click.secho(f"Could not get version for {package_name} using importlib.metadata: {str(e)}", fg="yellow") |
Code Review Run #bc105b
Is this a valid issue, or was it incorrectly flagged by the Agent?
- it was incorrectly flagged
flytekit/core/auto_cache.py
Outdated
... | ||
|
||
|
||
class CachePolicy: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can CachePolicy
live in the plugin? It makes sense for the abstract AutoCache
protocol to be defined in flytekit core, but any implementation of it should be in the plugin.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good point, thanks. i refactored this!
Signed-off-by: Daniel Sola <[email protected]>
Signed-off-by: Daniel Sola <[email protected]>
Code Review Agent Run #fa2b0bActionable Suggestions - 1
Review Details
|
if self.cache_version: | ||
return self.cache_version |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The code attempts to access self.cache_version
which appears to be undefined. Consider initializing this attribute in the __init__
method or removing this check if not needed.
Code suggestion
Check the AI-generated fix before applying
@@ -15,7 +15,8 @@
def __init__(
self,
auto_cache_policies: List[AutoCache] = None,
salt: str = "",
) -> None:
self.auto_cache_policies = auto_cache_policies or []
self.salt = salt
+ self.cache_version = None
Code Review Run #fa2b0b
Is this a valid issue, or was it incorrectly flagged by the Agent?
- it was incorrectly flagged
Why are the changes needed?
Make caching easier to use in flytekit by reducing cognitive burden of specifying cache versions
What changes were proposed in this pull request?
To use the caching mechanism in a Flyte task, you can define a
CachePolicy
that combines multiple caching strategies. Here’s an example of how to set it up:Salt Parameter
The
salt
parameter in theCachePolicy
adds uniqueness to the generated hash. It can be used to differentiate between different versions of the same task. This ensures that even if the underlying code remains unchanged, the hash will vary if a different salt is provided. This feature is particularly useful for invalidating the cache for specific versions of a task.Cache Implementations
Users can add any number of cache policies that implement the
AutoCache
protocol defined in@auto_cache.py
. Below are the implementations available so far:1. CacheFunctionBody
This implementation hashes the contents of the function of interest, ignoring any formatting or comment changes. It ensures that the core logic of the function is considered for versioning.
2. CacheImage
This implementation includes the hash of the
container_image
object passed. If the image is specified as a name, that string is hashed. If it is anImageSpec
, the parametrization of theImageSpec
is hashed, allowing for precise versioning of the container image used in the task.3. CachePrivateModules
This implementation recursively searches the task of interest for all callables and constants used. The contents of any callable (function or class) utilized by the task are hashed, ignoring formatting or comments. The values of the literal constants used are also included in the hash.
It accounts for both
import
andfrom-import
statements at the global and local levels within a module or function. Any callables that are within site-packages (i.e., external libraries) are ignored.4. CacheExternalDependencies
This implementation recursively searches through all the callables like
CachePrivateModules
, but when an external package is found, it records the version of the package, which is included in the hash. This ensures that changes in external dependencies are reflected in the task's versioning.How was this patch tested?
Unit tests for the following:
Setup process
Screenshots
Check all the applicable boxes
Related PRs
Docs link
Summary by Bito
This PR refactors Flytekit's caching mechanism by introducing a comprehensive auto-cache plugin that implements multiple strategies including function body hashing, container image versioning, and dependency tracking. The implementation migrates CachePolicy to a dedicated plugin, simplifying the core auto_cache module and streamlining cache parameter types while maintaining seamless integration with existing task and workflow decorators.Unit tests added: True
Estimated effort to review (1-5, lower is better): 5