blazefl.contrib.FedAvgBaseClientTrainer
- class blazefl.contrib.FedAvgBaseClientTrainer(model_selector: ModelSelector, model_name: str, dataset: PartitionedDataset[FedAvgPartitionType], device: str, num_clients: int, epochs: int, batch_size: int, lr: float, seed: int)
Bases: BaseClientTrainer[FedAvgUplinkPackage, FedAvgDownlinkPackage]
Base client trainer for the Federated Averaging (FedAvg) algorithm.
This trainer processes clients sequentially, training and evaluating a local model for each client based on the server-provided model parameters.
- model
The client’s local model.
- Type:
torch.nn.Module
- dataset
Dataset partitioned across clients.
- Type:
PartitionedDataset[FedAvgPartitionType]
- device
Device to run the model on (‘cpu’ or ‘cuda’).
- Type:
str
- num_clients
Total number of clients in the federation.
- Type:
int
- epochs
Number of local training epochs per client.
- Type:
int
- batch_size
Batch size for local training.
- Type:
int
- lr
Learning rate for the optimizer.
- Type:
float
- cache
Cache that stores uplink packages until they are sent to the server.
- Type:
list[FedAvgUplinkPackage]
- __init__(model_selector: ModelSelector, model_name: str, dataset: PartitionedDataset[FedAvgPartitionType], device: str, num_clients: int, epochs: int, batch_size: int, lr: float, seed: int) → None
Initialize the FedAvgBaseClientTrainer.
- Parameters:
model_selector (ModelSelector) – Selector for initializing the local model.
model_name (str) – Name of the model to be used.
dataset (PartitionedDataset) – Dataset partitioned across clients.
device (str) – Device to run the model on (‘cpu’ or ‘cuda’).
num_clients (int) – Total number of clients in the federation.
epochs (int) – Number of local training epochs per client.
batch_size (int) – Batch size for local training.
lr (float) – Learning rate for the optimizer.
seed (int) – Seed for reproducibility.
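The only documentation for `seed` is "Seed for reproducibility." A common way to make per-client data shuffling reproducible is to derive a separate random generator for each client from the global seed. The sketch below shows that idea with the standard library. The helper `client_batch_order` and its seed-mixing formula are hypothetical, and this is not necessarily how `FedAvgBaseClientTrainer` uses `seed`.

```python
import random


def client_batch_order(seed: int, cid: int, num_samples: int, batch_size: int) -> list[list[int]]:
    """Hypothetical helper: reproducible mini-batch index order for one client."""
    # Assumption: mix the global seed with the client ID so each client gets its
    # own stream that does not depend on the order in which clients are processed.
    rng = random.Random(seed * 100_003 + cid)
    indices = list(range(num_samples))
    rng.shuffle(indices)
    # Split the shuffled indices into consecutive mini-batches (last one may be short).
    return [indices[i:i + batch_size] for i in range(0, num_samples, batch_size)]


# Same seed and client ID produce the same batches on every run.
first = client_batch_order(seed=42, cid=3, num_samples=10, batch_size=4)
second = client_batch_order(seed=42, cid=3, num_samples=10, batch_size=4)
print(first == second)  # True
```

With a per-client generator, the order in which clients are processed does not change any single client's batches.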
Methods
- __init__(model_selector, model_name, ...): Initialize the FedAvgBaseClientTrainer.
- local_process(payload, cid_list): Train and evaluate the model for each client in the given list.
- train(model_parameters, train_loader, cid): Train the local model on the given training data loader.
- uplink_package(): Retrieve the uplink packages for transmission to the server.
- local_process(payload: FedAvgDownlinkPackage, cid_list: list[int]) → None
Train and evaluate the model for each client in the given list.
- Parameters:
payload (FedAvgDownlinkPackage) – Downlink package with global model parameters.
cid_list (list[int]) – List of client IDs to process.
- Returns:
None
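The control flow of `local_process` can be sketched as a plain sequential loop. Every client starts from the same global parameters in the payload, and its result is appended to `cache`. The sketch is a minimal illustration under several assumptions. `SequentialClientTrainer`, `DownlinkPackage`, and `UplinkPackage` are hypothetical stand-ins, plain float lists replace `torch.Tensor` and `DataLoader`, the "training" step is a toy update, and the evaluation step is omitted.

```python
from dataclasses import dataclass


@dataclass
class DownlinkPackage:  # hypothetical stand-in for FedAvgDownlinkPackage
    model_parameters: list[float]


@dataclass
class UplinkPackage:  # hypothetical stand-in for FedAvgUplinkPackage
    model_parameters: list[float]
    data_size: int


class SequentialClientTrainer:
    """Hypothetical trainer that mirrors only the documented control flow."""

    def __init__(self, client_data: dict[int, list[float]]) -> None:
        self.client_data = client_data  # cid -> toy local samples
        self.cache: list[UplinkPackage] = []

    def train(self, model_parameters: list[float], data: list[float], cid: int) -> UplinkPackage:
        # Toy update: move each parameter halfway toward the local data mean.
        mean = sum(data) / len(data)
        updated = [p + 0.5 * (mean - p) for p in model_parameters]
        return UplinkPackage(model_parameters=updated, data_size=len(data))

    def local_process(self, payload: DownlinkPackage, cid_list: list[int]) -> None:
        # Clients are handled one after another; each starts from the global
        # parameters in the payload (not from the previous client's result).
        for cid in cid_list:
            package = self.train(payload.model_parameters, self.client_data[cid], cid)
            self.cache.append(package)


trainer = SequentialClientTrainer({0: [1.0, 3.0], 1: [4.0]})
trainer.local_process(DownlinkPackage([0.0]), cid_list=[0, 1])
print([p.model_parameters for p in trainer.cache])  # [[1.0], [2.0]]
```

Processing clients one at a time keeps memory use to a single local model, at the cost of no parallelism across clients.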
- train(model_parameters: Tensor, train_loader: DataLoader, cid: int) → FedAvgUplinkPackage
Train the local model on the given training data loader.
- Parameters:
model_parameters (torch.Tensor) – Global model parameters used to initialize the local model.
train_loader (DataLoader) – DataLoader for the training data.
cid (int) – ID of the client being trained.
- Returns:
Uplink package containing updated model parameters and data size.
- Return type:
FedAvgUplinkPackage
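The local step of FedAvg is ordinary mini-batch SGD that starts from the global parameters and runs for `epochs` passes over the client's data. It returns the updated parameters together with the number of local samples. Below is a minimal sketch for a one-feature linear model `y ≈ w*x + b` with a mean-squared-error loss. It is written with the standard library only, so a list of `(x, y)` pairs stands in for `train_loader`. The function name, the extra `epochs`/`batch_size`/`lr`/`seed` arguments, and `UplinkPackage` are illustrative assumptions, not the library's signature.

```python
import random
from dataclasses import dataclass


@dataclass
class UplinkPackage:  # hypothetical stand-in for FedAvgUplinkPackage
    model_parameters: list[float]  # [w, b] after local training
    data_size: int  # number of local samples


def train(model_parameters: list[float], samples: list[tuple[float, float]],
          epochs: int, batch_size: int, lr: float, seed: int = 0) -> UplinkPackage:
    """Hypothetical local SGD for y ~ w*x + b, starting from the global parameters."""
    w, b = model_parameters  # unpacking copies the values; the caller's list is untouched
    rng = random.Random(seed)
    data = list(samples)  # shuffle a copy, not the caller's data
    for _ in range(epochs):
        rng.shuffle(data)  # new mini-batch order every epoch
        for start in range(0, len(data), batch_size):
            batch = data[start:start + batch_size]
            # Gradients of 0.5 * mean((w*x + b - y)^2) over the mini-batch.
            grad_w = sum((w * x + b - y) * x for x, y in batch) / len(batch)
            grad_b = sum((w * x + b - y) for x, y in batch) / len(batch)
            w -= lr * grad_w
            b -= lr * grad_b
    return UplinkPackage(model_parameters=[w, b], data_size=len(samples))


samples = [(x, 2.0 * x + 1.0) for x in [i / 5 - 1 for i in range(11)]]
package = train([0.0, 0.0], samples, epochs=300, batch_size=4, lr=0.1)
print(package.data_size)  # 11
print([round(p, 2) for p in package.model_parameters])  # approximately [2.0, 1.0]
```

The sample count is returned alongside the parameters so that the server can weight each client's update by how much data it trained on.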
- uplink_package() → list[FedAvgUplinkPackage]
Retrieve the uplink packages for transmission to the server.
- Returns:
A list of uplink packages.
- Return type:
list[FedAvgUplinkPackage]
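The packages returned here are what a FedAvg server combines. The standard FedAvg rule averages client parameters weighted by each client's data size. The sketch pairs a retrieval step with that weighted average. Resetting the cache after retrieval is an assumption about the trainer's behavior, and the names `CachingTrainer`, `UplinkPackage`, and `fedavg_aggregate` are hypothetical. The aggregation itself runs on the server and is not part of this class.

```python
from dataclasses import dataclass


@dataclass
class UplinkPackage:  # hypothetical stand-in for FedAvgUplinkPackage
    model_parameters: list[float]
    data_size: int


class CachingTrainer:
    """Hypothetical stand-in; only the cache handling is modeled."""

    def __init__(self) -> None:
        self.cache: list[UplinkPackage] = []

    def uplink_package(self) -> list[UplinkPackage]:
        # Assumption: hand over the accumulated packages and start a fresh cache,
        # so packages from one round are not resent in the next.
        packages, self.cache = self.cache, []
        return packages


def fedavg_aggregate(packages: list[UplinkPackage]) -> list[float]:
    # Weighted average: each client's parameters count in proportion to its data size.
    total = sum(p.data_size for p in packages)
    num_params = len(packages[0].model_parameters)
    return [
        sum(p.model_parameters[i] * p.data_size for p in packages) / total
        for i in range(num_params)
    ]


trainer = CachingTrainer()
trainer.cache = [UplinkPackage([1.0, 0.0], 1), UplinkPackage([3.0, 4.0], 3)]
packages = trainer.uplink_package()
print(fedavg_aggregate(packages))  # [2.5, 3.0]
print(len(trainer.cache))  # 0
```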