blazefl.contrib.FedAvgBaseClientTrainer

class blazefl.contrib.FedAvgBaseClientTrainer(model_selector: ModelSelector, model_name: str, dataset: PartitionedDataset[FedAvgPartitionType], device: str, num_clients: int, epochs: int, batch_size: int, lr: float, seed: int)

Bases: BaseClientTrainer[FedAvgUplinkPackage, FedAvgDownlinkPackage]

Base client trainer for the Federated Averaging (FedAvg) algorithm.

This trainer processes clients sequentially, training and evaluating a local model for each client based on the server-provided model parameters.
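The base class is generic over two package types: what a client trainer sends to the server (uplink) and what it receives from the server (downlink). The sketch below is a hypothetical, stdlib-only illustration of that pattern; `SketchClientTrainer`, `SketchDownlink`, `SketchUplink`, and the field names `model_parameters` and `data_size` are assumptions for illustration, not blazefl's actual definitions.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, TypeVar

UplinkT = TypeVar("UplinkT")      # client -> server package type
DownlinkT = TypeVar("DownlinkT")  # server -> client package type


class SketchClientTrainer(ABC, Generic[UplinkT, DownlinkT]):
    """Hypothetical stand-in for a client-trainer base class that is
    generic over its uplink and downlink package types."""

    @abstractmethod
    def local_process(self, payload: DownlinkT, cid_list: list[int]) -> None:
        """Run local work for every client ID in cid_list."""

    @abstractmethod
    def uplink_package(self) -> list[UplinkT]:
        """Return the packages to send back to the server."""


@dataclass
class SketchDownlink:
    # Assumed field: global model parameters broadcast by the server.
    model_parameters: list[float]


@dataclass
class SketchUplink:
    # Assumed fields: locally updated parameters plus the number of
    # local samples, which the server can use as an aggregation weight.
    model_parameters: list[float]
    data_size: int
```

A concrete trainer would subclass `SketchClientTrainer[SketchUplink, SketchDownlink]` and implement both abstract methods, so a type checker can verify that `local_process` consumes downlink packages and `uplink_package` produces uplink packages.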

model

The client’s local model.

Type:

torch.nn.Module

dataset

Dataset partitioned across clients.

Type:

PartitionedDataset

device

Device to run the model on (‘cpu’ or ‘cuda’).

Type:

str

num_clients

Total number of clients in the federation.

Type:

int

epochs

Number of local training epochs per client.

Type:

int

batch_size

Batch size for local training.

Type:

int

lr

Learning rate for the optimizer.

Type:

float

cache

Cache storing the uplink packages to be sent to the server.

Type:

list[FedAvgUplinkPackage]

__init__(model_selector: ModelSelector, model_name: str, dataset: PartitionedDataset[FedAvgPartitionType], device: str, num_clients: int, epochs: int, batch_size: int, lr: float, seed: int) → None

Initialize the FedAvgBaseClientTrainer.

Parameters:
  • model_selector (ModelSelector) – Selector for initializing the local model.

  • model_name (str) – Name of the model to be used.

  • dataset (PartitionedDataset) – Dataset partitioned across clients.

  • device (str) – Device to run the model on (‘cpu’ or ‘cuda’).

  • num_clients (int) – Total number of clients in the federation.

  • epochs (int) – Number of local training epochs per client.

  • batch_size (int) – Batch size for local training.

  • lr (float) – Learning rate for the optimizer.

  • seed (int) – Seed for reproducibility.
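A common way to make a run reproducible is to seed every random number generator it touches. The helper below is a hypothetical sketch (`set_all_seeds` is not a blazefl function), and which generators FedAvgBaseClientTrainer actually seeds with `seed` is an assumption; NumPy and PyTorch are seeded only if they are installed.

```python
import random


def set_all_seeds(seed: int) -> None:
    """Hypothetical helper: seed the RNGs a training run might use."""
    random.seed(seed)  # Python's built-in RNG
    try:
        import numpy as np
        np.random.seed(seed)  # NumPy's legacy global RNG
    except ImportError:
        pass  # NumPy not installed; nothing to seed.
    try:
        import torch
        torch.manual_seed(seed)  # PyTorch RNG on all devices
    except ImportError:
        pass  # PyTorch not installed; nothing to seed.


# Re-seeding with the same value reproduces the same random draws.
set_all_seeds(42)
first = [random.random() for _ in range(3)]
set_all_seeds(42)
second = [random.random() for _ in range(3)]
```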

Methods

__init__(model_selector, model_name, ...)

Initialize the FedAvgBaseClientTrainer.

local_process(payload, cid_list)

Train and evaluate the model for each client in the given list.

train(model_parameters, train_loader, cid)

Train the local model on the given training data loader.

uplink_package()

Retrieve the uplink packages for transmission to the server.

local_process(payload: FedAvgDownlinkPackage, cid_list: list[int]) → None

Train and evaluate the model for each client in the given list.

Parameters:
  • payload (FedAvgDownlinkPackage) – Downlink package with the global model parameters.

  • cid_list (list[int]) – List of client IDs to process.

Returns:

None
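The sequential per-client loop can be sketched without PyTorch. In this hypothetical stand-in, `SequentialClientLoop`, its placeholder `train`, the `client_data` mapping, and the `Uplink` fields are assumptions; evaluation is omitted, and the real method receives a FedAvgDownlinkPackage rather than a bare parameter list.

```python
from dataclasses import dataclass


@dataclass
class Uplink:
    # Hypothetical uplink package: updated parameters + local sample count.
    model_parameters: list[float]
    data_size: int


class SequentialClientLoop:
    """Hypothetical sketch of a sequential local_process loop."""

    def __init__(self, client_data: dict[int, list[float]]) -> None:
        self.client_data = client_data  # client ID -> local samples
        self.cache: list[Uplink] = []   # packages waiting for the server

    def train(self, model_parameters: list[float], data: list[float], cid: int) -> Uplink:
        # Placeholder "training": move each parameter halfway toward the
        # mean of this client's samples.
        mean = sum(data) / len(data)
        updated = [p + 0.5 * (mean - p) for p in model_parameters]
        return Uplink(updated, len(data))

    def local_process(self, global_parameters: list[float], cid_list: list[int]) -> None:
        # Clients are handled one after another. As in FedAvg, every client
        # starts from a fresh copy of the global parameters, not from the
        # previous client's result.
        for cid in cid_list:
            package = self.train(list(global_parameters), self.client_data[cid], cid)
            self.cache.append(package)
```

Because each client restarts from the parameters it was given, its package depends only on the global model and that client's data, not on the order in which clients are processed.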

train(model_parameters: Tensor, train_loader: DataLoader, cid: int) → FedAvgUplinkPackage

Train the local model on the given training data loader.

Parameters:
  • model_parameters (torch.Tensor) – Global model parameters used to initialize the local model.

  • train_loader (DataLoader) – DataLoader for the training data.

  • cid (int) – ID of the client being trained.

Returns:

Uplink package containing updated model parameters and data size.

Return type:

FedAvgUplinkPackage
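A minimal sketch of local training from the global parameters, assuming plain mini-batch SGD on a mean-squared-error loss. The real method works on torch.Tensor parameters and a DataLoader; here a two-parameter linear model y = w*x + b and hand-written gradients stand in for both, and `local_train` and `Uplink` are hypothetical names.

```python
import random
from dataclasses import dataclass


@dataclass
class Uplink:
    model_parameters: list[float]  # [w, b] after local training
    data_size: int                 # number of local training samples


def local_train(
    model_parameters: list[float],
    samples: list[tuple[float, float]],
    epochs: int,
    batch_size: int,
    lr: float,
    seed: int = 0,
) -> Uplink:
    """Hypothetical sketch: mini-batch SGD on y = w*x + b, starting
    from the global parameters received from the server."""
    w, b = model_parameters  # start from the server's parameters
    rng = random.Random(seed)
    order = list(range(len(samples)))
    for _ in range(epochs):
        rng.shuffle(order)  # new sample order every epoch
        for start in range(0, len(order), batch_size):
            batch = [samples[i] for i in order[start:start + batch_size]]
            # Gradients of the mean squared error over the mini-batch.
            grad_w = sum(2 * (w * x + b - y) * x for x, y in batch) / len(batch)
            grad_b = sum(2 * (w * x + b - y) for x, y in batch) / len(batch)
            w -= lr * grad_w
            b -= lr * grad_b
    return Uplink([w, b], len(samples))


# Noise-free data from y = 2x + 1; training should recover w = 2, b = 1.
data = [(0.0, 1.0), (1.0, 3.0), (2.0, 5.0), (3.0, 7.0)]
package = local_train([0.0, 0.0], data, epochs=500, batch_size=2, lr=0.05)
```

Returning the sample count alongside the parameters lets the server weight each client's contribution during aggregation.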

uplink_package() → list[FedAvgUplinkPackage]

Retrieve the uplink packages for transmission to the server.

Returns:

A list of uplink packages.

Return type:

list[FedAvgUplinkPackage]
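A hypothetical sketch of the hand-off from `cache` to the server, together with the standard FedAvg server step that consumes the returned packages by averaging parameters weighted by each client's data size. The aggregation function is not part of this class, and whether the real uplink_package clears the cache after returning it is an assumption made here for illustration; `CacheSketch`, `Uplink`, and `fedavg_aggregate` are hypothetical names.

```python
from dataclasses import dataclass


@dataclass
class Uplink:
    model_parameters: list[float]
    data_size: int


class CacheSketch:
    """Hypothetical sketch of the cache -> uplink_package hand-off."""

    def __init__(self) -> None:
        self.cache: list[Uplink] = []

    def uplink_package(self) -> list[Uplink]:
        # Assumption: hand over everything collected so far and start a
        # fresh cache, so one round's packages are not resent next round.
        packages, self.cache = self.cache, []
        return packages


def fedavg_aggregate(packages: list[Uplink]) -> list[float]:
    """Server-side FedAvg step: average parameters weighted by data_size."""
    total = sum(p.data_size for p in packages)
    dim = len(packages[0].model_parameters)
    return [
        sum(p.model_parameters[i] * p.data_size for p in packages) / total
        for i in range(dim)
    ]
```

For example, packages with parameters [1.0, 0.0] (1 sample) and [4.0, 3.0] (2 samples) aggregate to [3.0, 2.0], since the second client holds twice as much data.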