Source code for jboat.gm_ol.di

from jittor import Module
from ..higher_jit.patch import _MonkeyPatchBase
from ..higher_jit.optim import DifferentiableOptimizer
from typing import Dict, Any, Callable

from jboat.operation_registry import register_class
from jboat.gm_ol.dynamical_system import DynamicalSystem


[docs] @register_class class DI(DynamicalSystem): """ Implements the lower-level optimization procedure for Dynamic Initialization [1]. Parameters ---------- ll_objective : Callable The lower-level objective function of the BLO problem. ul_objective : Callable The upper-level objective function of the BLO problem. ll_model : jittor.Module The lower-level model of the BLO problem. ul_model : jittor.Module The upper-level model of the BLO problem. lower_loop : int The number of iterations for the lower-level optimization process. solver_config : Dict[str, Any] A dictionary containing configurations for the optimization solver, including hyperparameters and specific settings for NGD and GDA. References ---------- [1] Liu R., Liu Y., Zeng S., et al. "Towards gradient-based bilevel optimization with non-convex followers and beyond," in NeurIPS, 2021. """ def __init__( self, ll_objective: Callable, ul_objective: Callable, ll_model: Module, ul_model: Module, lower_loop: int, solver_config: Dict[str, Any], ): super(DI, self).__init__( ll_objective, ul_objective, lower_loop, ul_model, ll_model, solver_config )
[docs] def optimize( self, ll_feed_dict: Dict, ul_feed_dict: Dict, auxiliary_model: _MonkeyPatchBase, auxiliary_opt: DifferentiableOptimizer, current_iter: int, next_operation: str = None, **kwargs ): """ Executes the lower-level optimization procedure using the provided data and models. Parameters ---------- ll_feed_dict : Dict[str, Any] Dictionary containing the lower-level data used for optimization. Typically includes: - "data" : The input data for lower-level optimization. - "target" : The target output (optional, depending on the task). ul_feed_dict : Dict[str, Any] Dictionary containing the upper-level data used for optimization. Typically includes: - "data" : The input data for upper-level optimization. - "target" : The target output (optional, depending on the task). auxiliary_model : _MonkeyPatchBase A patched lower model wrapped by the `higher` library. Serves as the lower-level model for optimization in a differentiable way. auxiliary_opt : DifferentiableOptimizer A patched optimizer for the lower-level model, wrapped by the `higher` library. Allows for differentiable optimization steps. current_iter : int The current iteration number of the optimization process. next_operation : str Specifies the next operation to execute during the optimization process. Must not be None. **kwargs : dict Additional arguments passed to the optimization procedure. Returns ------- Dict A dictionary containing the input parameters and any additional keyword arguments. Raises ------ AssertionError If `next_operation` is not defined. Notes ----- Ensure that `next_operation` is defined before calling this function to specify the next operation in the optimization pipeline. """ assert next_operation is not None, "Next operation should be defined." return { "ll_feed_dict": ll_feed_dict, "ul_feed_dict": ul_feed_dict, "auxiliary_model": auxiliary_model, "auxiliary_opt": auxiliary_opt, "current_iter": current_iter, **kwargs, }