# Source code for jboat.fo_ol.meso

from jboat.utils.op_utils import (
    grad_unused_zero,
    update_tensor_grads,
    copy_parameter_from_list,
    manual_update,
)
import numpy
from jittor import Module
import copy
from typing import Dict, Any, Callable, List
import jittor as jit

from jboat.operation_registry import register_class
from jboat.gm_ol.dynamical_system import DynamicalSystem


@register_class
class MESO(DynamicalSystem):
    """
    Moreau Envelope based Single-loop Method (MESO) [1] for bi-level
    optimization (BLO).

    Parameters
    ----------
    ll_objective : Callable
        The lower-level objective of the BLO problem.
    lower_loop : int
        Number of iterations for lower-level optimization. Stored as
        ``self.y_loop``.
    ul_model : jittor.Module
        The upper-level model of the BLO problem.
    ul_objective : Callable
        The upper-level objective of the BLO problem.
    ll_model : jittor.Module
        The lower-level model of the BLO problem.
    ll_var : List[jittor.Var]
        The list of lower-level variables of the BLO problem.
    ul_var : List[jittor.Var]
        The list of upper-level variables of the BLO problem.
    solver_config : Dict[str, Any]
        A dictionary containing solver configurations. Keys read here:

        - "lower_level_opt": The optimizer for the lower-level model.
        - "MESO" (Dict): A dictionary containing the following keys:
            - "eta": Step size of the update applied to the surrogate
              variable `y_hat`.
            - "gamma_1": Weight of the proximal (regularization) term.
            - "c0": Constant scaling the penalty coefficient ``ck``.
            - "y_hat_lr": Learning rate of the SGD optimizer created for
              `y_hat`.

    References
    ----------
    [1] Liu R, Liu Z, Yao W, et al. "Moreau Envelope for Nonconvex Bi-Level
        Optimization: A Single-loop and Hessian-free Solution Strategy,"
        ICML, 2024.
    """

    def __init__(
        self,
        ll_objective: Callable,
        lower_loop: int,
        ul_model: Module,
        ul_objective: Callable,
        ll_model: Module,
        ll_var: List,
        ul_var: List,
        solver_config: Dict[str, Any],
    ):
        super(MESO, self).__init__(
            ll_objective, ul_objective, lower_loop, ul_model, ll_model, solver_config
        )
        self.ll_opt = solver_config["lower_level_opt"]
        self.ll_var = ll_var
        self.ul_var = ul_var
        self.y_loop = lower_loop

        # All algorithm-specific hyper-parameters live under the "MESO" key.
        meso_cfg = solver_config["MESO"]
        self.eta = meso_cfg["eta"]
        self.gamma_1 = meso_cfg["gamma_1"]
        self.c0 = meso_cfg["c0"]

        # Surrogate variable y_hat: an independent copy of the lower-level
        # model, with its own SGD optimizer.
        self.y_hat = copy.deepcopy(self.ll_model)
        self.y_hat_opt = jit.optim.SGD(
            self.y_hat.parameters(), lr=meso_cfg["y_hat_lr"], momentum=0.9
        )

    def optimize(self, ll_feed_dict: Dict, ul_feed_dict: Dict, current_iter: int):
        """
        Run one MESO iteration with the data from the feed dictionaries.

        The iteration (1) takes one explicit gradient step on the surrogate
        `y_hat`, (2) computes the gradient of the penalized lower-level loss
        with respect to the lower-level model, hands it to
        ``update_tensor_grads`` for ``ll_var`` and calls ``manual_update``
        with the lower-level optimizer, and (3) computes the gradient of the
        upper-level loss with respect to the upper-level model and hands it
        to ``update_tensor_grads`` for ``ul_var``. No upper-level optimizer
        step is taken inside this method.

        Parameters
        ----------
        ll_feed_dict : Dict
            Dictionary containing the lower-level data, forwarded unchanged
            to ``ll_objective``.
        ul_feed_dict : Dict
            Dictionary containing the upper-level data, forwarded unchanged
            to ``ul_objective``.
        current_iter : int
            The current iteration number of the optimization process.

        Returns
        -------
        Dict
            ``{"upper_loss": value}`` where ``value`` is ``upper_loss.item()``.
        """
        # Penalty coefficient: fixed at 0.2 on the very first iteration,
        # afterwards c0 * (k + 1) ** 0.25.
        if current_iter == 0:
            ck = 0.2
        else:
            ck = numpy.power(current_iter + 1, 0.25) * self.c0
        inv_ck = 1 / ck

        # --- (1) surrogate step: y_hat <- y_hat - eta * (grad + gamma_1 * (y_hat - y))
        surrogate_loss = self.ll_objective(ll_feed_dict, self.ul_model, self.y_hat)
        surrogate_grads = grad_unused_zero(
            surrogate_loss, list(self.y_hat.parameters())
        )
        stepped_params = [
            theta - self.eta * (g + self.gamma_1 * (theta - y))
            for theta, g, y in zip(
                self.y_hat.parameters(),
                surrogate_grads,
                self.ll_model.parameters(),
            )
        ]
        copy_parameter_from_list(self.y_hat, stepped_params)

        # Squared L2 distance between the lower-level parameters and the
        # freshly stepped surrogate parameters.
        reg = 0
        for y, theta_new in zip(self.ll_model.parameters(), stepped_params):
            gap = y - theta_new
            reg += (gap**2).sum()

        # --- (2) lower-level update
        ul_value = self.ul_objective(ul_feed_dict, self.ul_model, self.ll_model)
        ll_value = self.ll_objective(ll_feed_dict, self.ul_model, self.ll_model)
        lower_loss = inv_ck * ul_value + ll_value - 0.5 * self.gamma_1 * reg
        lower_grads = grad_unused_zero(
            lower_loss, list(self.ll_model.parameters())
        )
        update_tensor_grads(self.ll_var, lower_grads)
        manual_update(self.ll_opt, self.ll_var)

        # --- (3) upper-level gradient, evaluated after the lower-level update
        ul_value = self.ul_objective(ul_feed_dict, self.ul_model, self.ll_model)
        ll_value = self.ll_objective(ll_feed_dict, self.ul_model, self.ll_model)
        surrogate_value = self.ll_objective(
            ll_feed_dict, self.ul_model, self.y_hat
        )
        upper_loss = inv_ck * ul_value + ll_value - surrogate_value
        upper_grads = grad_unused_zero(
            upper_loss, list(self.ul_model.parameters())
        )
        update_tensor_grads(self.ul_var, upper_grads)

        return {"upper_loss": upper_loss.item()}