L2 Regularization with Jittor
This runnable example shows how to use the JBOAT library with the Jittor backend to solve a bi-level optimization problem with L2 regularization, covering end-to-end data loading (sparse-to-dense conversion), model/optimizer setup, solver construction, and iterative training with evaluation.
Step-by-Step Explanation
Step 1: Imports & Path Setup
import argparse
import numpy as np
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
import jittor as jit
import jboat as boat
from sklearn.model_selection import train_test_split
from sklearn.datasets import fetch_20newsgroups_vectorized
import json
Explanation:
Imports Jittor, JBOAT, and scikit-learn utilities for dataset loading and splitting.
Step 2: Configuration Loading
# Config files are resolved relative to the directory containing this script.
base_folder = os.path.dirname(os.path.abspath(__file__))

# Settings for the bi-level problem.
boat_config_path = os.path.join(base_folder, "configs_jit/boat_config_l2.json")
with open(boat_config_path, "r") as boat_fp:
    boat_config = json.load(boat_fp)

# Loss definitions passed to boat.Problem together with boat_config.
loss_config_path = os.path.join(base_folder, "configs_jit/loss_config_l2.json")
with open(loss_config_path, "r") as loss_fp:
    loss_config = json.load(loss_fp)
Explanation:
`boat_config_l2.json`: Contains the configuration for the bi-level optimization problem.
`loss_config_l2.json`: Defines the loss functions for both the upper-level and lower-level models.
Step 3: Data Preparation
def get_data(args, max_samples=2000):
    """
    Fetch the vectorized 20 Newsgroups data and return it as dense Jittor tensors.

    Args:
        args: Parsed CLI namespace; only ``args.data_path`` is read (passed to
            scikit-learn as ``data_home``).
        max_samples (int | None): When not None, keep only the first
            ``max_samples`` training rows and the first ``max_samples // 2``
            test rows before splitting.

    Returns:
        tuple: ``(train, val, test, teval)``, each an ``(x, y)`` pair where
        ``x`` is a dense float32 tensor and ``y`` an int64 label tensor.
    """

    def from_sparse(x):
        # Scatter the COO entries of a SciPy sparse matrix into a zero tensor.
        coo = x.tocoo()
        coords = jit.array(np.vstack((coo.row, coo.col)), dtype=jit.int64)
        entries = jit.array(coo.data, dtype=jit.float32)
        dense = jit.zeros(coo.shape, dtype=jit.float32)
        dense[coords[0], coords[1]] = entries
        return dense

    def fetch(subset):
        # Both subsets share the same cache directory and download policy.
        return fetch_20newsgroups_vectorized(
            subset=subset,
            return_X_y=True,
            data_home=args.data_path,
            download_if_missing=True,
        )

    train_x, train_y = fetch("train")
    test_x, test_y = fetch("test")

    # ---- Subsampling to reduce dataset size ----
    if max_samples is not None:
        half = max_samples // 2
        train_x, train_y = train_x[:max_samples], train_y[:max_samples]
        test_x, test_y = test_x[:half], test_y[:half]

    # Stratified 50/50 splits: train -> (train, val), test -> (test, teval).
    train_x, val_x, train_y, val_y = train_test_split(
        train_x, train_y, stratify=train_y, test_size=0.5
    )
    test_x, teval_x, test_y, teval_y = train_test_split(
        test_x, test_y, stratify=test_y, test_size=0.5
    )

    features = [from_sparse(m) for m in (train_x, val_x, test_x, teval_x)]
    labels = [
        jit.array(y, dtype=jit.int64) for y in (train_y, val_y, test_y, teval_y)
    ]

    # Report the number of samples in each split.
    print(*(y.shape[0] for y in labels))
    return tuple(zip(features, labels))
Explanation:
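The `get_data` function loads the dataset, optionally subsamples it, converts it to dense Jittor tensors, and splits it into training, validation, test, and evaluation sets. The processed data is saved to a file for future use (the save itself happens in Step 6, after `get_data` returns).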
Step 4: Evaluation Helper
def evaluate(x, w, testset):
    """
    Compute the cross-entropy loss and accuracy of a linear classifier on ``testset``.

    Args:
        x (jittor.Var): Model weights; logits are computed as ``test_x @ x``.
        w (jittor.Var): Upper-level variables. Not used here; kept for
            interface consistency.
        testset (tuple): ``(test_x, test_y)`` features and labels.

    Returns:
        tuple: ``(loss, acc)`` of the model on the test set.
    """
    features, labels = testset
    with jit.no_grad():
        logits = features @ x
        # Labels may arrive as a Jittor tensor or already as an array.
        labels_np = labels.numpy() if isinstance(labels, jit.Var) else labels
        loss = jit.nn.cross_entropy_loss(logits, jit.array(labels_np)).item()
        # Accuracy is computed with NumPy on the arg-max class per row.
        hits = logits.numpy().argmax(axis=-1) == labels_np
        acc = hits.sum() / len(labels_np)
    return loss, acc
Explanation:
Computes cross-entropy loss using Jittor.
Computes accuracy using NumPy for simplicity.
Step 5: Main Function & Argument Parsing
def main():
    def parse_args():
        """Parse the command line and seed NumPy and Jittor with ``--seed``."""
        parser = argparse.ArgumentParser()
        add = parser.add_argument

        # Run-mode flags.
        add("--generate_data", action="store_true", default=False)
        add("--pretrain", action="store_true", default=False)

        # Schedule and storage locations.
        add("--epochs", type=int, default=1000)
        add("--iterations", type=int, default=10, help="T")
        add("--data_path", default="./data", help="where to save data")
        add("--model_path", default="./save_l2reg", help="where to save model")

        # Learning rates and momenta.
        add("--x_lr", type=float, default=100)
        add("--xhat_lr", type=float, default=100)
        add("--w_lr", type=float, default=1000)
        add("--w_momentum", type=float, default=0.9)
        add("--x_momentum", type=float, default=0.9)

        # Algorithm-specific settings.
        add("--K", type=int, default=10, help="k")
        add("--u1", type=float, default=1.0)
        add("--BVFSM_decay", type=str, default="log", choices=["log", "power2"])
        add("--seed", type=int, default=1)
        add(
            "--alg",
            type=str,
            default="BOME",
            choices=[
                "BOME",
                "BSG_1",
                "penalty",
                "AID_CG",
                "AID_FP",
                "ITD",
                "BVFSM",
                "baseline",
                "VRBO",
                "reverse",
                "stocBiO",
                "MRBO",
            ],
        )

        # Comma-separated strategy names.
        add("--gm_op", type=str, default="DM,NGD")
        add("--na_op", type=str, default="RAD")
        add("--fo_op", type=str, default=None)

        parsed = parser.parse_args()
        # Seed both RNGs with the same value.
        np.random.seed(parsed.seed)
        jit.set_global_seed(parsed.seed)
        return parsed

    args = parse_args()
Explanation:
Defines CLI arguments controlling the dataset path, the optimization strategies (`gm_op`, `na_op`, or `fo_op`), and the random seed.
Step 6: Data Setting and Model Initialization
# Build the four (x, y) splits, then store them under the data directory.
splits = get_data(args)
trainset, valset, testset, tevalset = splits

cache_path = os.path.join(args.data_path, "l2reg.pkl")
jit.save(splits, cache_path)
print(f"[info] successfully generated data to {args.data_path}/l2reg.pkl")
class UpperModel(jit.Module):
    """Upper-level model holding a zero-initialised vector ``x`` of length ``n_feats``."""

    def __init__(self, n_feats):
        # Regularization parameters, one per feature, starting at 0.0.
        zeros = jit.init.constant([n_feats], "float32", 0.0)
        self.x = zeros.clone()

    def execute(self):
        """Return the stored vector ``x``."""
        return self.x
class LowerModel(jit.Module):
    """Lower-level model holding an ``(n_feats, num_classes)`` weight matrix ``y``."""

    def __init__(self, n_feats, num_classes):
        # Allocate the classifier weights, then apply Kaiming-normal init in place.
        weights = jit.zeros([n_feats, num_classes])
        self.y = weights
        jit.init.kaiming_normal_(
            self.y, a=0, mode="fan_in", nonlinearity="leaky_relu"
        )

    def execute(self):
        """Return the weight matrix ``y``."""
        return self.y
# Feature count comes from the last axis of the training inputs.
n_feats = trainset[0].shape[-1]
upper_model = UpperModel(n_feats)

# Class count is the largest training label plus one.
num_classes = int(trainset[1].max().item()) + 1
lower_model = LowerModel(n_feats, num_classes)
Explanation:
Saves the processed dataset for reuse as `l2reg.pkl`.
`UpperModel`: Represents the upper-level model with a single learnable parameter.
`LowerModel`: Represents the lower-level model, initialized using the Kaiming initialization strategy.
Step 7: Optimizer & Strategy Setup
# One optimizer per level: Adam for the upper model, SGD for the lower model.
upper_opt = jit.nn.Adam(upper_model.parameters(), lr=0.01)
lower_opt = jit.nn.SGD(lower_model.parameters(), lr=0.01)

# Strategy lists arrive as comma-separated strings; an empty value means None.
gm_op = None
if args.gm_op:
    gm_op = args.gm_op.split(",")

na_op = None
if args.na_op:
    na_op = args.na_op.split(",")
Explanation:
Adam optimizer: Used for the upper-level model to update its parameters.
SGD optimizer: Applied to the lower-level model for efficient gradient updates.
The `gm_op` and `na_op` parameters allow flexible optimization strategies.
Step 8: Bi-Level Optimization Setup
# Configure JBOAT problem
# When the "RGT" strategy is selected, set its truncate_iter to 1.
if na_op is not None and "RGT" in na_op:
    boat_config["RGT"]["truncate_iter"] = 1

# Attach strategies, models, optimizers and variables to the JBOAT config.
boat_config.update(
    {
        "gm_op": gm_op,
        "na_op": na_op,
        "fo_op": args.fo_op,
        "lower_level_model": lower_model,
        "upper_level_model": upper_model,
        "lower_level_opt": lower_opt,
        "upper_level_opt": upper_opt,
        "lower_level_var": list(lower_model.parameters()),
        "upper_level_var": list(upper_model.parameters()),
    }
)

# Create the problem, then build the lower-level solver before the upper-level one.
b_optimizer = boat.Problem(boat_config, loss_config)
b_optimizer.build_ll_solver()
b_optimizer.build_ul_solver()
Explanation:
Configures the `boat_config` with models, optimizers, and variables for both levels.
Instantiates the `boat.Problem` class and builds the necessary lower-level and upper-level solvers.
Step 9: Optimization Loop
# Upper level is fed the training split, lower level the validation split.
ul_feed_dict = {"data": trainset[0], "target": trainset[1]}
ll_feed_dict = {"data": valset[0], "target": valset[1]}

# boat_config["gm_op"] is None when --gm_op is empty (Step 7 maps an empty
# value to None); fall back to an empty list so the membership tests below
# do not raise TypeError.
gm_ops = boat_config["gm_op"] or []
uses_dm = "DM" in gm_ops
uses_gda = "GDA" in gm_ops

# Determine iteration count based on strategy.
iterations = 3 if uses_dm and uses_gda else 2

for x_itr in range(iterations):
    # Dynamic strategy adjustment for Dynamic Methods (DM); skipped when a
    # first-order operation (fo_op) is configured.
    if uses_dm and boat_config["fo_op"] is None:
        if uses_gda:
            # With GDA the strategy follows the iteration index: s1, s2, s3.
            strategy = "s" + str(x_itr % 3 + 1)
        else:
            strategy = "s" + str(1)
        b_optimizer._ll_solver.gradient_instances[-1].strategy = strategy

    loss, run_time = b_optimizer.run_iter(
        ll_feed_dict, ul_feed_dict, current_iter=x_itr
    )
Explanation:
The `evaluate` function (Step 4) calculates the model's loss and accuracy on the test dataset; calling it inside the loop outputs the test performance metrics for monitoring optimization progress (the snippet above does not call it).
The `run_iter` function iterates over the bi-level optimization process using JBOAT.
Step 10: Entry Point
# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
Explanation:
Standard Python entry point that makes the script runnable directly.
How to Run
To execute the example, use the following command:
python your_script_name.py --data_path ./data --model_path ./save_l2reg --gm_op NGD --na_op RAD