Source code for jboat.gm_ol.gda

from jittor import Module
from typing import Callable
from ..higher_jit.patch import _MonkeyPatchBase
from ..higher_jit.optim import DifferentiableOptimizer
from typing import Dict, Any, Callable

from jboat.operation_registry import register_class
from jboat.gm_ol.dynamical_system import DynamicalSystem


[docs] @register_class class GDA(DynamicalSystem): """ Implements the optimization procedure of the Gradient Descent Aggregation (GDA) [1]. Parameters ---------- ll_objective : Callable The lower-level objective function of the BLO problem. ul_objective : Callable The upper-level objective function of the BLO problem. ll_model : jittor.Module The lower-level model of the BLO problem. ul_model : jittor.Module The upper-level model of the BLO problem. lower_loop : int The number of iterations for lower-level optimization. solver_config : Dict[str, Any] A dictionary containing configurations for the solver. Expected keys include: - "GDA" (Dict): Configuration for the GDA algorithm: - "alpha_init" (float): Initial learning rate for the GDA updates. - "alpha_decay" (float): Decay rate for the learning rate. - "gda_loss" (Callable): The loss function used in the GDA optimization. References ---------- [1] R. Liu, P. Mu, X. Yuan, S. Zeng, and J. Zhang, "A generic first-order algorithmic framework for bi-level programming beyond lower-level singleton", in ICML, 2020. """ def __init__( self, ll_objective: Callable, ul_objective: Callable, ll_model: Module, ul_model: Module, lower_loop: int, solver_config: Dict[str, Any], ): super(GDA, self).__init__( ll_objective, ul_objective, lower_loop, ul_model, ll_model, solver_config ) self.alpha = solver_config["GDA"]["alpha_init"] self.alpha_decay = solver_config["GDA"]["alpha_decay"] self.gda_loss = solver_config["gda_loss"]
[docs] def optimize( self, ll_feed_dict: Dict, ul_feed_dict: Dict, auxiliary_model: _MonkeyPatchBase, auxiliary_opt: DifferentiableOptimizer, current_iter: int, next_operation: str = None, **kwargs ): """ Execute the lower-level optimization procedure using provided data, models, and patched optimizers. Parameters ---------- ll_feed_dict : Dict Dictionary containing the lower-level data used for optimization. Typically includes training data, targets, and other information required to compute the LL objective. ul_feed_dict : Dict Dictionary containing the upper-level data used for optimization. Typically includes validation data, targets, and other information required to compute the UL objective. auxiliary_model : _MonkeyPatchBase A patched lower model wrapped by the `higher` library. This model is used for differentiable optimization in the lower-level procedure. auxiliary_opt : DifferentiableOptimizer A patched optimizer for the lower-level model, wrapped by the `higher` library. Allows differentiable optimization. current_iter : int The current iteration number of the optimization process. next_operation : str, optional Specifies the next operation to be executed in the optimization pipeline. Default is None. **kwargs : dict Additional parameters required for the optimization procedure. Returns ------- dict A dictionary containing: - "ll_feed_dict" : Dict Lower-level feed dictionary. - "ul_feed_dict" : Dict Upper-level feed dictionary. - "auxiliary_model" : _MonkeyPatchBase Patched lower-level model. - "auxiliary_opt" : DifferentiableOptimizer Patched lower-level optimizer. - "current_iter" : int Current iteration number. - "gda_loss" : callable Gradient Descent Aggregation (GDA) loss function. - "alpha" : float Coefficient used in the GDA operation, typically in (0, 1). - "alpha_decay" : float Decay factor for the coefficient `alpha`. Raises ------ AssertionError If `next_operation` is not defined. If `alpha` is not in the range (0, 1). If `gda_loss` is not properly defined. Notes ----- - The method assumes that `gda_loss` is defined and accessible from the instance attributes. - The coefficient `alpha` and its decay rate `alpha_decay` must be properly configured. """ assert next_operation is not None, "Next operation should be defined." assert (self.alpha > 0) and ( self.alpha < 1 ), "Set the coefficient alpha properly in (0,1)." assert ( self.gda_loss is not None ), "Define the gda_loss properly in loss_func.py." return { "ll_feed_dict": ll_feed_dict, "ul_feed_dict": ul_feed_dict, "auxiliary_model": auxiliary_model, "auxiliary_opt": auxiliary_opt, "current_iter": current_iter, "gda_loss": self.gda_loss, "alpha": self.alpha, "alpha_decay": self.alpha_decay, **kwargs, }