Source code for jboat.boat_opt

import time
import copy
from typing import Dict, Any, Callable
from jboat.utils.op_utils import (
    copy_parameter_from_list,
    average_grad,
    manual_update,
)

import jittor as jit
from jittor.optim import Optimizer
import jboat.higher_jit as higher

importlib = __import__("importlib")
from jboat.operation_registry import get_registered_operation
from jboat.gm_ol import makes_functional_dynamical_system
from jboat.na_ol import makes_functional_na_operation

import matplotlib.pyplot as plt
import os
import json


def _load_loss_function(loss_config: Dict[str, Any]) -> Callable:
    module_name, func_name = loss_config["function"].rsplit(".", 1)
    module = importlib.import_module(module_name)
    func = getattr(module, func_name)
    return lambda *args, **kwargs: func(
        *args, **{**loss_config.get("params", {}), **kwargs}
    )


[docs] class Problem: """ Jittor-based bi-level optimization Problem, aligned in interface and logic with the Torch version. """ def __init__(self, config: Dict[str, Any], loss_config: Dict[str, Any]): self._fo_op = config["fo_op"] self._gm_op = config["gm_op"] self._na_op = config["na_op"] self._ll_model = config["lower_level_model"] self._ul_model = config["upper_level_model"] self._ll_var = config["lower_level_var"] self._ul_var = config["upper_level_var"] self.boat_configs = config if config["gm_op"] is not None: if "GDA" in config["gm_op"]: assert ( loss_config.get("gda_loss", None) is not None ), "Set the 'gda_loss' in loss_config properly." self.boat_configs["gda_loss"] = _load_loss_function( loss_config["gda_loss"] ) self._lower_loop = config.get("lower_iters", 10) self._lower_opt = self.boat_configs["lower_level_opt"] self._upper_opt = self.boat_configs["upper_level_opt"] self._ll_loss = _load_loss_function(loss_config["lower_level_loss"]) self._ul_loss = _load_loss_function(loss_config["upper_level_loss"]) self._ll_solver = None self._ul_solver = None self._lower_init_opt = None self._fo_op_solver = None self._log_results = [] self._track_opt_traj = False self.loss_log_path = config["loss_log_path"] self.loss_history = []
[docs] def build_ll_solver(self): if self.boat_configs["fo_op"] is None: assert (self.boat_configs["gm_op"] is not None) and ( self.boat_configs["na_op"] is not None ), "Set 'gm_op' and 'na_op' properly." self.check_status() # DM need auxiliary variables if "DM" in self._gm_op: self.boat_configs["DM"]["auxiliary_v"] = [ jit.zeros_like(param) for param in self._ll_var ] # these auxiliary variables require gradients for v in self.boat_configs["DM"]["auxiliary_v"]: v.start_grad() self.boat_configs["DM"]["auxiliary_v_opt"] = jit.nn.SGD( self.boat_configs["DM"]["auxiliary_v"], lr=self.boat_configs["DM"]["auxiliary_v_lr"], ) sorted_ops = sorted([op.upper() for op in self._gm_op]) self._ll_solver = makes_functional_dynamical_system( custom_order=sorted_ops, ll_objective=self._ll_loss, ul_objective=self._ul_loss, ll_model=self._ll_model, ul_model=self._ul_model, lower_loop=self._lower_loop, solver_config=self.boat_configs, ) # DI: dynamic initialization optimizer if "DI" in self.boat_configs["gm_op"]: opt_cls = type(self._upper_opt) di_lr = float(self.boat_configs["DI"]["lr"]) self._lower_init_opt = opt_cls(self._ll_var, di_lr) else: self._fo_op_solver = get_registered_operation( "%s" % self.boat_configs["fo_op"] )( ll_objective=self._ll_loss, ul_objective=self._ul_loss, ll_model=self._ll_model, ul_model=self._ul_model, lower_loop=self._lower_loop, ll_var=self._ll_var, ul_var=self._ul_var, solver_config=self.boat_configs, ) return self
[docs] def build_ul_solver(self): if self.boat_configs["fo_op"] is None: assert ( self.boat_configs["na_op"] is not None ), "Choose na_op properly when not using FOGM." sorted_ops = sorted([op.upper() for op in self._na_op]) self._ul_solver = makes_functional_na_operation( custom_order=sorted_ops, ul_objective=self._ul_loss, ll_objective=self._ll_loss, ll_model=self._ll_model, ul_model=self._ul_model, ll_var=self._ll_var, ul_var=self._ul_var, solver_config=self.boat_configs, ) else: # FOGO does not need UL solver assert ( self.boat_configs["fo_op"] is not None ), "FOGM is enabled, UL solver is handled by FOGM op." return self
[docs] def run_iter( self, ll_feed_dict: Dict[str, jit.Var], ul_feed_dict: Dict[str, jit.Var], current_iter: int, ) -> tuple: if self.boat_configs["fo_op"] is not None: start_time = time.perf_counter() # FIX: align with Torch version if self.boat_configs.get("fogm_batch_input", False): for batch_ll_feed_dict, batch_ul_feed_dict in zip( ll_feed_dict, ul_feed_dict ): self._log_results.append( self._fo_op_solver.optimize( batch_ll_feed_dict, batch_ul_feed_dict, current_iter ) ) else: self._log_results.append( self._fo_op_solver.optimize(ll_feed_dict, ul_feed_dict, current_iter) ) run_time = time.perf_counter() - start_time else: run_time = 0.0 if self.boat_configs["accumulate_grad"]: for batch_ll_feed_dict, batch_ul_feed_dict in zip( ll_feed_dict, ul_feed_dict ): with higher.innerloop_ctx( self._ll_model, self._lower_opt, copy_initial_weights=False, track_higher_grads=self._track_opt_traj, ) as (auxiliary_model, auxiliary_opt): forward_time = time.perf_counter() dynamic_results = self._ll_solver.optimize( ll_feed_dict=batch_ll_feed_dict, ul_feed_dict=batch_ul_feed_dict, auxiliary_model=auxiliary_model, auxiliary_opt=auxiliary_opt, current_iter=current_iter, ) self._log_results.append(dynamic_results) max_loss_iter = list(dynamic_results[-1].values())[-1] forward_time = time.perf_counter() - forward_time backward_time = time.perf_counter() if "DM" not in self._gm_op: self._log_results.append( self._ul_solver.compute_gradients( ll_feed_dict=batch_ll_feed_dict, ul_feed_dict=batch_ul_feed_dict, # FIX auxiliary_model=auxiliary_model, max_loss_iter=max_loss_iter, ) ) else: self._log_results.append( self._ul_loss( batch_ul_feed_dict, self._ul_model, auxiliary_model ) ) backward_time = time.perf_counter() - backward_time run_time += forward_time + backward_time else: with higher.innerloop_ctx( self._ll_model, self._lower_opt, copy_initial_weights=False, track_higher_grads=self._track_opt_traj, ) as (auxiliary_model, auxiliary_opt): forward_time = time.perf_counter() dynamic_results = self._ll_solver.optimize( ll_feed_dict=ll_feed_dict, ul_feed_dict=ul_feed_dict, auxiliary_model=auxiliary_model, auxiliary_opt=auxiliary_opt, current_iter=current_iter, ) self._log_results.append(dynamic_results) max_loss_iter = list(dynamic_results[-1].values())[-1] forward_time = time.perf_counter() - forward_time print("forward_time", forward_time) backward_time = time.perf_counter() if "DM" not in self._gm_op: self._log_results.append( self._ul_solver.compute_gradients( ll_feed_dict=ll_feed_dict, ul_feed_dict=ul_feed_dict, auxiliary_model=auxiliary_model, max_loss_iter=max_loss_iter, ) ) else: # DM:use UL loss to update upper variables self._log_results.append( self._ul_loss(ul_feed_dict, self._ul_model, auxiliary_model) ) backward_time = time.perf_counter() - backward_time print("backward_time", backward_time) if self.boat_configs["copy_last_param"]: copy_parameter_from_list( self._ll_model, list(auxiliary_model.parameters(time=-1)) ) # DI:use dynamic initialization optimizer to update lower variables if "DI" in self.boat_configs["gm_op"]: manual_update( self._lower_init_opt, self._lower_opt.param_groups[0]["params"] ) run_time = forward_time + backward_time if isinstance(ll_feed_dict, list): ll_fd = ll_feed_dict[0] ul_fd = ul_feed_dict[0] else: ll_fd = ll_feed_dict ul_fd = ul_feed_dict if not self.boat_configs["return_grad"]: manual_update(self._upper_opt, self._ul_var) else: ll_loss = self._ll_loss(ll_fd, self._ul_model, self._ll_model) ul_loss = self._ul_loss(ul_fd, self._ul_model, self._ll_model) print(f"ll_loss: {ll_loss.item()} ul_loss: {ul_loss.item()}") self.save_losses(current_iter=current_iter, ll_loss=ll_loss, ul_loss=ul_loss) return [var._custom_grad for var in list(self._ul_var)], run_time ll_loss = self._ll_loss(ll_fd, self._ul_model, self._ll_model) ul_loss = self._ul_loss(ul_fd, self._ul_model, self._ll_model) print(f"ll_loss: {ll_loss.item()} ul_loss: {ul_loss.item()}") self.save_losses(current_iter=current_iter, ll_loss=ll_loss, ul_loss=ul_loss) return self._log_results, run_time
[docs] def set_track_trajectory(self, track_traj=True): self._track_opt_traj = track_traj
[docs] def check_status(self): if any(item in self._na_op for item in ["PTT", "IAD", "RAD"]): self.set_track_trajectory(True) if "DM" in self.boat_configs["gm_op"]: assert (self.boat_configs["na_op"] == ["RAD"]) or ( self.boat_configs["na_op"] == ["CG"] ), "When 'DM' is chosen, set the 'truncate_iter' properly." if "RGT" in self.boat_configs["na_op"]: assert ( self.boat_configs["RGT"]["truncate_iter"] > 0 ), "When 'RGT' is chosen, set the 'truncate_iter' properly ." if self.boat_configs["accumulate_grad"]: assert ( "IAD" in self.boat_configs["na_op"] ), "When using 'accumulate_grad', only 'IAD' based methods are supported." if self.boat_configs["GDA"]["alpha_init"] > 0.0: assert ( 0.0 < self.boat_configs["GDA"]["alpha_decay"] <= 1.0 ), "Parameter 'alpha_decay' used in method BDA should be in the interval (0,1)." if "FD" in self._na_op: assert ( self.boat_configs["RGT"]["truncate_iter"] == 0 ), "One-stage method doesn't need trajectory truncation." def check_model_structure(base_model, meta_model): for param1, param2 in zip(base_model.parameters(), meta_model.parameters()): if (param1.shape != param2.shape) or (param1.dtype != param2.dtype): return False return True if "IAD" in self._na_op: assert check_model_structure(self._ll_model, self._ul_model), ( "With IAD or FOA operation, 'upper_level_model' and 'lower_level_model' have the same structure, " "and 'lower_level_var' and 'upper_level_var' are the same group of variables." ) assert (("DI" in self._gm_op) ^ ("IAD" in self._na_op)) or ( ("DI" not in self._gm_op) and ("IAD" not in self._na_op) ), "Only one of the 'PTT' and 'RGT' methods could be chosen." assert ( 0.0 <= self.boat_configs["GDA"]["alpha_init"] <= 1.0 ), "Parameter 'alpha' used in method BDA should be in the interval (0,1)." assert ( self.boat_configs["RGT"]["truncate_iter"] < self.boat_configs["lower_iters"] ), "The value of 'truncate_iter' shouldn't be greater than 'lower_loop'."
[docs] def plot_losses(self): iters = [x["iter"] for x in self.loss_history] ll_losses = [x["ll_loss"] for x in self.loss_history] ul_losses = [x["ul_loss"] for x in self.loss_history] fig, axes = plt.subplots(1, 2, figsize=(12, 5)) axes[0].plot(iters, ll_losses, label="Lower-level Loss", color="blue") axes[0].set_xlabel("Iteration") axes[0].set_ylabel("Loss") axes[0].set_title("Lower-level Loss") axes[0].legend(loc="upper left") axes[0].grid(True) axes[1].plot(iters, ul_losses, label="Upper-level Loss", color="orange") axes[1].set_xlabel("Iteration") axes[1].set_ylabel("Loss") axes[1].set_title("Upper-level Loss") axes[1].legend(loc="upper left") axes[1].grid(True) plt.tight_layout() save_path = os.path.join(os.path.dirname(self.loss_log_path), "loss_curve.png") plt.savefig(save_path) plt.close()
[docs] def save_losses(self, current_iter, ll_loss, ul_loss): self.loss_history.append( { "iter": current_iter, "ll_loss": float(ll_loss.item()), "ul_loss": float(ul_loss.item()), } ) with open(self.loss_log_path, "w") as f: json.dump(self.loss_history, f)