# Source code for jboat.na_ol.sequential_hg

from typing import List, Dict
from jboat.utils import HyperGradientRules, ResultStore
from jboat.operation_registry import get_registered_operation


class SequentialHG:
    """
    Run a chain of hyper-gradient operators one after another.

    The first operator receives the caller's keyword arguments; every later
    operator receives the previous operator's return value, unpacked as
    keyword arguments. Each operator's output is recorded in a result store.

    Parameters
    ----------
    ordered_instances : List[object]
        Instantiated gradient operator objects, already in execution order.
    custom_order : List[str]
        The user-defined order of gradient operators. It is indexed by
        position to find the name handed to each operator as
        ``next_operation``.
    """

    def __init__(self, ordered_instances: List[object], custom_order: List[str]):
        self.gradient_instances = ordered_instances
        self.custom_order = custom_order
        # Dedicated store that collects the outputs of a single run.
        self.result_store = ResultStore()

    def compute_gradients(self, **kwargs) -> List[Dict]:
        """
        Execute every gradient operator in order and collect the outputs.

        Parameters
        ----------
        **kwargs : dict
            Keyword arguments forwarded to the first operator only.

        Returns
        -------
        List[Dict]
            Whatever ``self.result_store.get_results()`` returns after all
            operators have run.

        Notes
        -----
        Each operator's return value is unpacked with ``**`` for the next
        operator, so it has to be a mapping with string keys.
        """
        # Start from an empty store so earlier runs do not leak into this one.
        self.result_store.clear()

        carried = None
        for position, operator in enumerate(self.gradient_instances):
            inputs = kwargs if position == 0 else carried

            # Name at the following position in custom_order, or None once
            # the list is exhausted.
            following = position + 1
            if following < len(self.custom_order):
                upcoming = self.custom_order[following]
            else:
                upcoming = None

            carried = operator.compute_gradients(**inputs, next_operation=upcoming)

            # NOTE(review): every operator is stored under the same key;
            # confirm ResultStore.add is meant to accumulate rather than
            # overwrite entries with equal names.
            self.result_store.add("gradient_operator_results", carried)

        return self.result_store.get_results()
def makes_functional_na_operation(custom_order: List[str], **kwargs) -> SequentialHG:
    """
    Build a SequentialHG from a user-supplied list of operator names.

    The names are arranged to follow the predefined gradient order, each
    name in ``custom_order`` is resolved to its registered operator class,
    and one instance per arranged name is created.

    Parameters
    ----------
    custom_order : List[str]
        The user-defined order of gradient operators.
    **kwargs : dict
        Keyword arguments passed unchanged to every operator constructor.

    Returns
    -------
    SequentialHG
        Wrapper holding the operator instances (in adjusted order) together
        with the original ``custom_order``.
    """
    # Arrange the requested names according to the predefined rules.
    adjusted_order = validate_and_adjust_order(
        custom_order, HyperGradientRules.get_gradient_order()
    )

    # Resolve every requested name, including any that the adjustment
    # step dropped.
    gradient_classes = {
        name: get_registered_operation(name) for name in custom_order
    }

    ordered_instances = []
    for name in adjusted_order:
        operator_cls = gradient_classes[name]
        ordered_instances.append(operator_cls(**kwargs))

    # NOTE(review): instances follow adjusted_order, but SequentialHG is
    # given custom_order and indexes it to pick `next_operation`. When the
    # two orders differ the names may not line up with the instances;
    # confirm whether adjusted_order should be passed instead.
    return SequentialHG(ordered_instances, custom_order)
def validate_and_adjust_order(
    custom_order: List[str], gradient_order: List[List[str]]
) -> List[str]:
    """
    Arrange the requested operators to follow the predefined gradient order.

    The result lists the operators of ``gradient_order``, group by group and
    in their order within each group, keeping only those that also appear in
    ``custom_order``. Names in ``custom_order`` that belong to no group are
    dropped silently, and repeated names in ``custom_order`` do not produce
    repeated output.

    Parameters
    ----------
    custom_order : List[str]
        The user-provided order of gradient operators.
    gradient_order : List[List[str]]
        The predefined order of gradient operator groups.

    Returns
    -------
    List[str]
        Requested operators, ordered as in ``gradient_order``.
    """
    # Only operators drawn from gradient_order are ever emitted, so unknown
    # names need no separate filtering pass. A set gives constant-time
    # membership tests instead of rescanning the list for every operator.
    requested = set(custom_order)

    adjusted_order: List[str] = []
    for group in gradient_order:
        adjusted_order.extend(op for op in group if op in requested)
    return adjusted_order