# Source code for jboat.gm_ol.sequential_ds

from typing import List, Dict
from jboat.utils import DynamicalSystemRules, ResultStore
from jboat.operation_registry import get_registered_operation

importlib = __import__("importlib")


class SequentialDS:
    """
    Run a fixed sequence of gradient operators, chaining each output into the next.

    Attributes
    ----------
    gradient_instances : List[object]
        Operator instances, executed in list order. Each must expose an
        ``optimize(**kwargs)`` method.
    custom_order : List[str]
        Operator names used only to compute the ``next_operation`` hint:
        the operator at position ``idx`` receives ``custom_order[idx + 1]``
        (or ``None`` past the end). It is therefore expected to line up
        index-for-index with ``gradient_instances``.
    result_store : ResultStore
        Dedicated store that records the result of every operator step.
    """

    def __init__(self, ordered_instances: List[object], custom_order: List[str]):
        """
        Store the operators, their name order, and a fresh result store.

        Parameters
        ----------
        ordered_instances : List[object]
            Gradient operator instances, in the order they will be executed.
        custom_order : List[str]
            Operator names aligned with ``ordered_instances``; used to tell
            each operator which operator runs after it.
        """
        self.gradient_instances = ordered_instances
        self.custom_order = custom_order
        self.result_store = ResultStore()  # Use a dedicated result store

    def optimize(self, **kwargs) -> List[Dict]:
        """
        Execute the operators sequentially, feeding each result into the next.

        Parameters
        ----------
        **kwargs : dict
            Keyword arguments passed to the first operator's ``optimize``.
            Later operators do not see these; they receive the previous
            operator's result instead.

        Returns
        -------
        List[Dict]
            The value of ``self.result_store.get_results()`` after every step
            has been recorded under the key ``"dynamic_results_{idx}"``.

        Notes
        -----
        - The result store is cleared at the start of every call.
        - Each operator's result is unpacked with ``**`` as the keyword
          arguments of the next operator, so every non-final result must be
          a mapping.
        - The result of *every* step is stored, not only the final one.
        """
        self.result_store.clear()  # Drop results from any previous run
        intermediate_result = None
        for idx, gradient_instance in enumerate(self.gradient_instances):
            # The first operator gets the caller's kwargs; each later operator
            # gets the previous operator's result.
            step_inputs = kwargs if idx == 0 else intermediate_result
            # Hint which operator runs next; None marks the last position.
            next_operation = (
                self.custom_order[idx + 1]
                if idx + 1 < len(self.custom_order)
                else None
            )
            # NOTE(review): if step_inputs already contains a "next_operation"
            # key, this call raises TypeError (duplicate keyword argument);
            # confirm operators never return that key.
            intermediate_result = gradient_instance.optimize(
                **step_inputs,
                next_operation=next_operation,
            )
            # Record this step's result (every step is kept, not just the last).
            self.result_store.add(f"dynamic_results_{idx}", intermediate_result)
        return self.result_store.get_results()
def makes_functional_dynamical_system(
    custom_order: List[str], **kwargs
) -> SequentialDS:
    """
    Build a SequentialDS whose operators follow the predefined gradient order.

    Parameters
    ----------
    custom_order : List[str]
        User-requested operator names. They are passed through
        ``validate_and_adjust_order`` against
        ``DynamicalSystemRules.get_gradient_order()``: names absent from the
        predefined order are dropped and the rest are re-sorted into it.
    **kwargs
        Forwarded unchanged to the constructor of every operator class.

    Returns
    -------
    SequentialDS
        Instance whose ``gradient_instances`` and ``custom_order`` both follow
        the adjusted order, so each step's ``next_operation`` hint names the
        operator that actually runs next.
    """
    # Load the predefined gradient order
    gradient_order = DynamicalSystemRules.get_gradient_order()

    # Adjust custom order based on predefined gradient order
    adjusted_order = validate_and_adjust_order(custom_order, gradient_order)

    # Resolve only the operators that survived validation; looking up the raw
    # custom_order would also query the registry for names that get discarded.
    gradient_classes = {op: get_registered_operation(op) for op in adjusted_order}

    # Instantiate operators in execution order
    ordered_instances = [gradient_classes[op](**kwargs) for op in adjusted_order]

    # Pass adjusted_order (not custom_order): SequentialDS indexes this list in
    # parallel with ordered_instances to compute each step's next_operation.
    return SequentialDS(ordered_instances, adjusted_order)
def validate_and_adjust_order(
    custom_order: List[str], gradient_order: List[List[str]]
) -> List[str]:
    """
    Validate and adjust the custom order to align with the predefined gradient operator groups.

    Parameters
    ----------
    custom_order : List[str]
        The user-defined order of gradient operators.
    gradient_order : List[List[str]]
        The predefined grouping of gradient operators, specifying valid order constraints.

    Returns
    -------
    List[str]
        The operators of ``gradient_order`` (flattened, in its order) that also
        appear in ``custom_order``.

    Notes
    -----
    - Operators in ``custom_order`` that do not appear in ``gradient_order``
      are silently dropped.
    - The position of an operator in ``custom_order`` is ignored; only
      membership matters. Repeated names in ``custom_order`` therefore
      appear once per occurrence in ``gradient_order``.

    Example
    -------
    >>> custom_order = ["op1", "op3", "op2"]
    >>> gradient_order = [["op1", "op2"], ["op3"]]
    >>> adjusted_order = validate_and_adjust_order(custom_order, gradient_order)
    >>> print(adjusted_order)
    ['op1', 'op2', 'op3']
    """
    # A set gives O(1) membership tests. Names not in gradient_order can never
    # be emitted below, so no separate filtering pass is needed.
    requested = set(custom_order)
    return [op for group in gradient_order for op in group if op in requested]