Source code for jboat.na_ol.rad

import jittor as jit
from jittor import Module
from typing import List, Callable, Dict
from ..higher_jit.patch import _MonkeyPatchBase
from jboat.utils.op_utils import update_tensor_grads

from jboat.operation_registry import register_class
from jboat.na_ol.hyper_gradient import HyperGradient


[docs] @register_class class RAD(HyperGradient): """ Computes the hyper-gradient of the upper-level variables using Reverse Auto Differentiation (RAD) [1]. Parameters ---------- ll_objective : Callable The lower-level objective function of the BLO problem. ul_objective : Callable The upper-level objective function of the BLO problem. ll_model : jittor.Module The lower-level model of the BLO problem. ul_model : jittor.Module The upper-level model of the BLO problem. ll_var : List[jittor.Var] List of variables optimized with the lower-level objective. ul_var : List[jittor.Var] List of variables optimized with the upper-level objective. solver_config : Dict[str, Any] Dictionary containing solver configurations, including optional dynamic operation settings. References ---------- [1] Franceschi, Luca, et al. "Forward and reverse gradient-based hyperparameter optimization." in ICML, 2017. """ def __init__( self, ll_objective: Callable, ul_objective: Callable, ll_model: Module, ul_model: Module, ll_var: List, ul_var: List, solver_config: Dict, ): super(RAD, self).__init__( ll_objective, ul_objective, ul_model, ll_model, ll_var, ul_var, solver_config, ) self.dynamic_initialization = "DI" in solver_config["gm_op"]
[docs] def compute_gradients( self, ll_feed_dict: Dict, ul_feed_dict: Dict, auxiliary_model: _MonkeyPatchBase, max_loss_iter: int = 0, next_operation: str = None, **kwargs ): """ Compute the hyper-gradients of the upper-level variables using the provided data and patched models. Parameters ---------- ll_feed_dict : Dict Dictionary containing the lower-level data used for optimization. Typically includes training data, targets, and other information required to compute the lower-level objective. ul_feed_dict : Dict Dictionary containing the upper-level data used for optimization. Typically includes validation data, targets, and other information required to compute the upper-level objective. auxiliary_model : _MonkeyPatchBase A patched lower model wrapped by the `higher` library. It serves as the lower-level model for optimization. max_loss_iter : int, optional The number of iterations used for backpropagation. Default is 0. next_operation : str, optional The next operator for the calculation of the hypergradient. Default is None. **kwargs : dict Additional keyword arguments passed to the method. Returns ------- Dict A dictionary containing the upper-level objective and the status of hypergradient computation. """ assert next_operation is None, "RAD does not support any further operations." lower_model_params = kwargs.get( "lower_model_params", list(auxiliary_model.parameters()) ) upper_loss = self.ul_objective( ul_feed_dict, self.ul_model, auxiliary_model, params=lower_model_params ) grads_upper = jit.grad( upper_loss, self.ul_var, retain_graph=self.dynamic_initialization ) update_tensor_grads(self.ul_var, grads_upper) if self.dynamic_initialization: grads_lower = jit.grad(upper_loss, list(auxiliary_model.parameters(time=0))) update_tensor_grads(self.ll_var, grads_lower) return {"upper_loss": upper_loss.item(), "hyper_gradient_finished": True}