blazefl.contrib.FedAvgThreadPoolClientTrainer

class blazefl.contrib.FedAvgThreadPoolClientTrainer(model_selector: ModelSelector, model_name: str, dataset: PartitionedDataset[FedAvgPartitionType], device: str, num_clients: int, epochs: int, batch_size: int, lr: float, seed: int, num_parallels: int)

Bases: ThreadPoolClientTrainer[FedAvgUplinkPackage, FedAvgDownlinkPackage]

__init__(model_selector: ModelSelector, model_name: str, dataset: PartitionedDataset[FedAvgPartitionType], device: str, num_clients: int, epochs: int, batch_size: int, lr: float, seed: int, num_parallels: int) → None

Methods

__init__(model_selector, model_name, ...)

get_client_device(cid)

local_process(payload, cid_list)

Process the downlink payload from the server for a list of client IDs.

progress_fn(it)

A no-op progress function that can be overridden to provide custom progress tracking.

train(model, model_parameters, train_loader, ...)

Train the model with the given training data loader.

uplink_package()

Prepare the data package to be sent from the client to the server.

worker(cid, device, payload, stop_event)

Process a single client's training task in a thread.

Attributes

num_parallels

device

device_count

cache

stop_event
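Read together, the methods and attributes above describe a thread-pool pattern: local_process fans out one worker call per client ID over num_parallels threads, progress_fn wraps the resulting futures, and uplink_package hands the collected results to the server. The sketch below illustrates that pattern with the standard library only; it is not blazefl's implementation. ThreadPoolTrainerSketch, UplinkSketch and fake_worker are hypothetical names, and three behaviors are assumptions rather than documented facts: get_client_device assigns CUDA devices round-robin by client ID, finished packages are kept in cache until uplink_package drains it, and stop_event is set when a worker fails.

```python
import threading
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from typing import Any, Callable, Iterable


@dataclass
class UplinkSketch:
    # Hypothetical stand-in for FedAvgUplinkPackage.
    cid: int
    model_parameters: list[float]
    data_size: int


Worker = Callable[[int, str, Any, threading.Event], UplinkSketch]


class ThreadPoolTrainerSketch:
    """Hypothetical illustration of the thread-pool pattern; not blazefl code."""

    def __init__(self, worker: Worker, device: str, device_count: int, num_parallels: int) -> None:
        self.worker = worker
        self.device = device
        self.device_count = device_count
        self.num_parallels = num_parallels
        self.cache: list[UplinkSketch] = []
        self.stop_event = threading.Event()

    def get_client_device(self, cid: int) -> str:
        # Assumption: spread clients round-robin over the visible CUDA devices.
        if self.device == "cuda" and self.device_count > 0:
            return f"cuda:{cid % self.device_count}"
        return self.device

    def progress_fn(self, it: list[Future]) -> Iterable[Future]:
        # No-op, matching the documented default.
        return it

    def local_process(self, payload: Any, cid_list: list[int]) -> None:
        with ThreadPoolExecutor(max_workers=self.num_parallels) as pool:
            futures = [
                pool.submit(self.worker, cid, self.get_client_device(cid), payload, self.stop_event)
                for cid in cid_list
            ]
            try:
                for future in self.progress_fn(futures):
                    self.cache.append(future.result())
            except Exception:
                # Assumption: ask the remaining workers to stop, then re-raise.
                self.stop_event.set()
                raise

    def uplink_package(self) -> list[UplinkSketch]:
        # Hand over the collected packages and reset the cache for the next round.
        package, self.cache = self.cache, []
        return package


def fake_worker(cid: int, device: str, payload: list[float], stop_event: threading.Event) -> UplinkSketch:
    # Hypothetical worker: offsets the global parameters by the client ID.
    return UplinkSketch(cid, [p + cid for p in payload], data_size=10 * (cid + 1))


trainer = ThreadPoolTrainerSketch(fake_worker, device="cpu", device_count=0, num_parallels=2)
trainer.local_process([0.0, 1.0], cid_list=[0, 1, 2])
print([pkg.cid for pkg in trainer.uplink_package()])  # [0, 1, 2]
```

Because progress_fn is a no-op in this sketch, results are collected in submission order, so the list returned by uplink_package lines up with cid_list.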

progress_fn(it: list[Future[FedAvgUplinkPackage]]) → Iterable[Future[FedAvgUplinkPackage]]

A no-op progress function that can be overridden to provide custom progress tracking.

Parameters:

it (list[Future[UplinkPackage]]) – A list of Future objects representing the results of client processing.

Returns:

The original iterable.

Return type:

Iterable[Future[UplinkPackage]]
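Since progress_fn receives the list of futures and must return an iterable of them, one way to override it is a generator that reports progress as each client finishes. A minimal stdlib sketch follows, written as a standalone function with the documented signature; in practice it would be a method on a subclass. Printing to stderr is an arbitrary choice for illustration.

```python
import sys
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from typing import Iterable, TypeVar

T = TypeVar("T")


def progress_fn(it: list[Future[T]]) -> Iterable[Future[T]]:
    """Yield futures as they finish and report how many clients are done."""
    total = len(it)
    for done, future in enumerate(as_completed(it), start=1):
        print(f"clients finished: {done}/{total}", file=sys.stderr)
        yield future


# Usage with stand-in tasks instead of client training jobs.
with ThreadPoolExecutor(max_workers=2) as pool:
    futures = [pool.submit(pow, 2, n) for n in range(4)]
    results = sorted(f.result() for f in progress_fn(futures))
print(results)  # [1, 2, 4, 8]
```

Because as_completed yields futures in completion order, results may arrive in a different order than the client IDs were submitted; if downstream code relies on that order, report progress without reordering the futures.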

train(model: Module, model_parameters: Tensor, train_loader: DataLoader, device: str, epochs: int, lr: float, stop_event: Event, cid: int) → FedAvgUplinkPackage

Train the model with the given training data loader.

Parameters:
  • model (torch.nn.Module) – The model to train.

  • model_parameters (torch.Tensor) – Initial global model parameters.

  • train_loader (DataLoader) – DataLoader for the training data.

  • device (str) – Device to run the training on.

  • epochs (int) – Number of local training epochs.

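  • lr (float) – Learning rate for the optimizer.

  • stop_event (threading.Event) – Event to signal stopping the training.

  • cid (int) – The client ID.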

Returns:

Uplink package containing updated model parameters and data size.

Return type:

FedAvgUplinkPackage
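The parameters above describe local FedAvg training: start from the global model_parameters, run epochs passes over train_loader with learning rate lr, and return the updated parameters together with the data size. To keep it runnable without PyTorch, the sketch below replaces the torch model with a one-dimensional linear model y = w * x + b trained by plain SGD on mean squared error, and replaces train_loader with a list of (x, y) batches. train_sketch and UplinkSketch are hypothetical; the model, device and cid arguments are dropped, and returning the partially trained parameters when stop_event is set is an assumption about how the event is honored.

```python
import threading
from dataclasses import dataclass


@dataclass
class UplinkSketch:
    # Hypothetical stand-in for FedAvgUplinkPackage.
    model_parameters: list[float]
    data_size: int


def train_sketch(
    model_parameters: list[float],
    train_loader: list[list[tuple[float, float]]],
    epochs: int,
    lr: float,
    stop_event: threading.Event,
) -> UplinkSketch:
    """Local SGD on a toy linear model y = w * x + b, starting from the global parameters."""
    w, b = model_parameters
    data_size = sum(len(batch) for batch in train_loader)  # number of local samples
    for _ in range(epochs):
        for batch in train_loader:
            if stop_event.is_set():
                # Assumption: stop early and return the parameters trained so far.
                return UplinkSketch([w, b], data_size)
            # Gradients of the batch mean squared error with respect to w and b.
            n = len(batch)
            grad_w = sum(2 * (w * x + b - y) * x for x, y in batch) / n
            grad_b = sum(2 * (w * x + b - y) for x, y in batch) / n
            w -= lr * grad_w
            b -= lr * grad_b
    return UplinkSketch([w, b], data_size)


loader = [[(0.0, 1.0), (1.0, 3.0)], [(2.0, 5.0), (3.0, 7.0)]]  # samples of y = 2x + 1
pkg = train_sketch([0.0, 0.0], loader, epochs=500, lr=0.05, stop_event=threading.Event())
print([round(p, 2) for p in pkg.model_parameters], pkg.data_size)  # [2.0, 1.0] 4
```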

uplink_package()

Prepare the data package to be sent from the client to the server.

Returns:

A list of data packages prepared for uplink transmission.

Return type:

list[UplinkPackage]

worker(cid: int, device: str, payload: FedAvgDownlinkPackage, stop_event: Event) → FedAvgUplinkPackage

Process a single client’s training task in a thread.

Parameters:
  • cid (int) – The client ID.

  • device (str) – The device to use for processing this client.

  • payload (DownlinkPackage) – The data package received from the server.

  • stop_event (threading.Event) – Event to signal stopping the worker.

Returns:

The uplink package containing the client’s results.

Return type:

UplinkPackage
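worker is the unit of work submitted to the thread pool for a single client. The sketch below shows one plausible shape for it, under assumptions: the client's data is looked up by cid in a hypothetical partitions mapping (standing in for the PartitionedDataset passed to the constructor), the global parameters are copied from the payload so concurrent threads never share one mutable object, and training is delegated to a hypothetical train_fn callable that also receives stop_event. DownlinkSketch, UplinkSketch, worker_sketch and fake_train are illustrative names, not blazefl APIs.

```python
import threading
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class DownlinkSketch:
    # Hypothetical stand-in for FedAvgDownlinkPackage.
    model_parameters: list[float]


@dataclass
class UplinkSketch:
    # Hypothetical stand-in for FedAvgUplinkPackage.
    model_parameters: list[float]
    data_size: int


# Hypothetical signature: (parameters, client data, device, stop_event) -> uplink package.
TrainFn = Callable[[list[float], list[Any], str, threading.Event], UplinkSketch]


def worker_sketch(
    cid: int,
    device: str,
    payload: DownlinkSketch,
    stop_event: threading.Event,
    partitions: dict[int, list[Any]],
    train_fn: TrainFn,
) -> UplinkSketch:
    """Run one client's local training; intended to be submitted to a thread pool."""
    client_data = partitions[cid]  # this client's share of the partitioned dataset
    # Copy the global parameters so concurrent workers never share one mutable list.
    params = list(payload.model_parameters)
    return train_fn(params, client_data, device, stop_event)


def fake_train(params, data, device, stop_event):
    # Hypothetical trainer: nudges every parameter by 1.0 and reports the sample count.
    return UplinkSketch([p + 1.0 for p in params], data_size=len(data))


payload = DownlinkSketch([0.5, -0.5])
pkg = worker_sketch(1, "cpu", payload, threading.Event(), {1: [(0.0, 1.0)] * 3}, fake_train)
print(pkg)  # UplinkSketch(model_parameters=[1.5, 0.5], data_size=3)
```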