# Source code for jboat.gm_ol.dm

from .dynamical_system import DynamicalSystem
import jittor as jit
from jittor import Module
from ..higher_jit.patch import _MonkeyPatchBase
from ..higher_jit.optim import DifferentiableOptimizer
from typing import Dict, Any, Callable
from ..utils.op_utils import (
    update_tensor_grads,
    grad_unused_zero,
    list_tensor_norm,
    list_tensor_matmul,
    custom_grad,
    manual_update,
)

from jboat.operation_registry import register_class
from jboat.gm_ol.dynamical_system import DynamicalSystem


@register_class
class DM(DynamicalSystem):
    """
    Implements the lower-level optimization procedure for Dual Multiplier (DM) [1].

    Parameters
    ----------
    ll_objective : Callable
        The lower-level objective function of the BLO problem.
    lower_loop : int
        The number of iterations for the lower-level optimization process.
    ul_model : jittor.Module
        The upper-level model of the BLO problem.
    ul_objective : Callable
        The upper-level objective function of the BLO problem.
    ll_model : jittor.Module
        The lower-level model of the BLO problem.
    solver_config : Dict[str, Any]
        Solver configuration. The constructor reads the following entries:

        - ``"na_op"`` : container of enabled operation names (checked for
          ``"PTT"``, ``"RGT"`` and, during :meth:`optimize`, ``"RAD"``).
        - ``"GDA"`` : ``"alpha_init"``, ``"alpha_decay"``.
        - ``"RGT"`` : ``"truncate_iter"`` (only read when ``"RGT"`` is in ``"na_op"``).
        - ``"lower_level_opt"``, ``"upper_level_opt"`` : optimizers for the
          lower- and upper-level variables.
        - ``"DM"`` : ``"auxiliary_v"``, ``"auxiliary_v_opt"``, ``"auxiliary_v_lr"``,
          ``"tau"``, ``"p"``, ``"mu0"``, ``"eta0"``, ``"strategy"``.
        - ``"gda_loss"`` : optional aggregated (GDA) loss callable; ``None`` if absent.

    References
    ----------
    [1] Liu R, Liu Y, Yao W, et al., "Averaged method of multipliers for bi-level
        optimization without lower-level strong convexity," ICML, 2023.
    """

    def __init__(
        self,
        ll_objective: Callable,
        lower_loop: int,
        ul_model: Module,
        ul_objective: Callable,
        ll_model: Module,
        solver_config: Dict[str, Any],
    ):
        # NOTE: the base class is called with its own positional order, which
        # differs from this constructor's parameter order.
        super(DM, self).__init__(
            ll_objective, ul_objective, lower_loop, ul_model, ll_model, solver_config
        )
        self.truncate_max_loss_iter = "PTT" in solver_config["na_op"]
        self.alpha = solver_config["GDA"]["alpha_init"]
        self.alpha_decay = solver_config["GDA"]["alpha_decay"]
        self.truncate_iters = (
            solver_config["RGT"]["truncate_iter"]
            if "RGT" in solver_config["na_op"]
            else 0
        )
        self.ll_opt = solver_config["lower_level_opt"]
        self.ul_opt = solver_config["upper_level_opt"]
        self.auxiliary_v = solver_config["DM"]["auxiliary_v"]
        self.auxiliary_v_opt = solver_config["DM"]["auxiliary_v_opt"]
        self.auxiliary_v_lr = solver_config["DM"]["auxiliary_v_lr"]
        self.tau = solver_config["DM"]["tau"]
        self.p = solver_config["DM"]["p"]
        self.mu0 = solver_config["DM"]["mu0"]
        self.eta = solver_config["DM"]["eta0"]
        self.strategy = solver_config["DM"]["strategy"]
        self.na_op = solver_config["na_op"]
        self.gda_loss = solver_config.get("gda_loss", None)

    def optimize(
        self,
        ll_feed_dict: Dict,
        ul_feed_dict: Dict,
        auxiliary_model: _MonkeyPatchBase,
        auxiliary_opt: DifferentiableOptimizer,
        current_iter: int,
        next_operation: str = None,
        **kwargs
    ) -> int:
        """
        Run one DM step: refresh step sizes, then set gradients for the
        lower-level variables, the auxiliary multiplier ``v`` and the
        upper-level variables.

        Parameters
        ----------
        ll_feed_dict : Dict
            Lower-level data passed to the lower-level (or GDA) loss. When the
            GDA path is active, the key ``"alpha"`` is written into this dict.
        ul_feed_dict : Dict
            Upper-level data passed to the upper-level objective (and to the
            GDA loss when it is used).
        auxiliary_model : _MonkeyPatchBase
            Patched lower-level model whose parameters are differentiated.
        auxiliary_opt : DifferentiableOptimizer
            Patched lower-level optimizer. Not used by this method; accepted
            for interface compatibility.
        current_iter : int
            Current iteration index (0-based in the step-size schedules, which
            use ``current_iter + 1``).
        next_operation : str, optional
            Must be ``None``; DM does not chain to another operation.
        kwargs : dict
            If it contains ``"gda_loss"``, the GDA step-size schedule
            (strategies ``'s1'``, ``'s2'``, ``'s3'``) is used.

        Returns
        -------
        int
            Always ``-1``.

        Raises
        ------
        AssertionError
            If ``next_operation`` is not ``None``, if the strategy is not one
            of ``'s1'``/``'s2'``/``'s3'`` on the GDA path, or if the strategy
            is not ``'s1'`` on the non-GDA path.

        Notes
        -----
        - With ``"RAD"`` in ``na_op`` the multiplier ``v`` is updated through
          ``auxiliary_v_opt``; otherwise ``v`` is replaced in closed form using
          a step size computed from the current residual.
        - The lower-level model is updated in place via ``ll_opt``; gradients
          for the upper-level model are only attached, not applied, here.
        """
        assert next_operation is None, "DM does not support next_operation"

        step = current_iter + 1
        ll_lr = self.ll_opt.defaults["lr"]

        if "gda_loss" in kwargs:
            gda_loss = kwargs["gda_loss"]
            assert self.strategy in [
                "s1",
                "s2",
                "s3",
            ], "Three strategies are supported for DM operation, including ['s1','s2','s3']."
            # The aggregation weight and the decay factors are shared by all
            # three strategies; only the powers of alpha differ.
            self.alpha = self.mu0 * 1 / step ** (1 / self.p)
            eta_decay = step ** (-0.5 * self.tau)
            x_decay = step ** (-1.5 * self.tau)
            if self.strategy == "s1":
                self.eta = eta_decay * self.alpha**2 * ll_lr
                x_lr = x_decay * self.alpha**7 * ll_lr
            elif self.strategy == "s2":
                self.eta = eta_decay * self.alpha * ll_lr
                x_lr = x_decay * self.alpha**5 * ll_lr
            elif self.strategy == "s3":
                self.eta = eta_decay * ll_lr
                x_lr = x_decay * self.alpha**3 * ll_lr
            # NOTE(review): self.eta is refreshed here but is not written into
            # auxiliary_v_opt on this path - confirm that this is intended.
            for params in self.ul_opt.param_groups:
                params["lr"] = x_lr
        else:
            gda_loss = None
            assert (
                self.strategy == "s1"
            ), "Only 's1' strategy is supported for DM without GDA operation."
            x_lr = self.ul_opt.defaults["lr"] * step ** (-self.tau) * ll_lr
            eta = self.eta * step ** (-0.5 * self.tau) * ll_lr
            for params in self.auxiliary_v_opt.param_groups:
                params["lr"] = eta
            for params in self.ul_opt.param_groups:
                params["lr"] = x_lr

        if gda_loss is not None:
            ll_feed_dict["alpha"] = self.alpha
            # Prefer the loss configured on the solver; fall back to the one
            # supplied by the caller so a missing config entry does not end in
            # "'NoneType' object is not callable".
            gda_fn = self.gda_loss if self.gda_loss is not None else gda_loss
            loss_full = gda_fn(
                ll_feed_dict, ul_feed_dict, self.ul_model, auxiliary_model
            )
        else:
            loss_full = self.ll_objective(ll_feed_dict, self.ul_model, auxiliary_model)

        aux_params = list(auxiliary_model.parameters())
        ul_params = list(self.ul_model.parameters())
        ll_params = list(self.ll_model.parameters())

        # Gradient of the lower-level (or aggregated) loss w.r.t. y, used for
        # the lower-level update.
        grad_y_temp = jit.grad(loss_full, aux_params, retain_graph=True)

        upper_loss = self.ul_objective(ul_feed_dict, self.ul_model, auxiliary_model)
        grad_outer_params = grad_unused_zero(upper_loss, aux_params, retain_graph=True)
        grads_phi_params = grad_unused_zero(loss_full, aux_params, retain_graph=True)
        grads = custom_grad(
            grads_phi_params,
            ul_params,
            self.auxiliary_v,
            retain_graph=True,
        )  # dx (dy f) v
        grad_outer_hparams = grad_unused_zero(upper_loss, ul_params)

        vsp = custom_grad(
            grads_phi_params,
            aux_params,
            grad_outputs=self.auxiliary_v,
        )  # dy (dy f) v=d2y f v

        use_rad = "RAD" in self.na_op
        if not use_rad:
            # Closed-form update of v along the residual (d2y f v - dy F).
            # list_tensor_norm / list_tensor_matmul are presumably the norm and
            # inner product over the tensor lists - TODO confirm.
            tem = [v - gow for v, gow in zip(vsp, grad_outer_params)]
            ita_u = list_tensor_norm(tem) ** 2
            grad_tem = custom_grad(grads_phi_params, aux_params, grad_outputs=tem)
            ita_l = list_tensor_matmul(tem, grad_tem)
            ita = ita_u / (ita_l + 1e-12)  # epsilon guards against a zero denominator
            self.auxiliary_v = [
                v0 - ita * v + ita * gow
                for v0, v, gow in zip(self.auxiliary_v, vsp, grad_outer_params)
            ]
            # Recompute the product with the refreshed multiplier.
            vsp = custom_grad(
                grads_phi_params,
                aux_params,
                grad_outputs=self.auxiliary_v,
            )

        # Attach the residual as the gradient of each multiplier component.
        for v0, v, gow in zip(self.auxiliary_v, vsp, grad_outer_params):
            v0._custom_grad = v - gow

        update_tensor_grads(ll_params, grad_y_temp)
        manual_update(self.ll_opt, ll_params)
        if use_rad:
            manual_update(self.auxiliary_v_opt, self.auxiliary_v)

        # Upper-level gradient: direct gradient minus the cross term.
        grads = [
            -g + v if g is not None else v for g, v in zip(grads, grad_outer_hparams)
        ]
        update_tensor_grads(ul_params, grads)
        return -1