blazefl.contrib.FedAvgThreadPoolClientTrainer
- class blazefl.contrib.FedAvgThreadPoolClientTrainer(model_selector: ModelSelector, model_name: str, dataset: PartitionedDataset[FedAvgPartitionType], device: str, num_clients: int, epochs: int, batch_size: int, lr: float, seed: int, num_parallels: int)
Bases: ThreadPoolClientTrainer[FedAvgUplinkPackage, FedAvgDownlinkPackage]
- __init__(model_selector: ModelSelector, model_name: str, dataset: PartitionedDataset[FedAvgPartitionType], device: str, num_clients: int, epochs: int, batch_size: int, lr: float, seed: int, num_parallels: int) → None
Methods
- __init__(model_selector, model_name, ...)
- get_client_device(cid)
- local_process(payload, cid_list) – Process the downlink payload from the server for a list of client IDs.
- progress_fn(it) – A no-op progress function that can be overridden to provide custom progress tracking.
- train(model, model_parameters, train_loader, ...) – Train the model with the given training data loader.
- uplink_package() – Prepare the data package to be sent from the client to the server.
- worker(cid, device, payload, stop_event) – Process a single client's training task in a thread.
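The summary does not describe what get_client_device returns or how local_process spreads clients over num_parallels threads. The following is a minimal sketch of one plausible scheme. It assumes round-robin assignment of clients to device_count CUDA devices and one thread-pool task per client ID. ToyDownlink, ToyUplink and the three functions are hypothetical stand-ins written without importing blazefl, so they are not the library's implementation.

```python
import threading
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass


@dataclass
class ToyDownlink:
    # Hypothetical stand-in for FedAvgDownlinkPackage.
    global_value: float


@dataclass
class ToyUplink:
    # Hypothetical stand-in for FedAvgUplinkPackage.
    cid: int
    device: str
    value: float


def get_client_device(cid: int, device: str, device_count: int) -> str:
    # Assumption: round-robin over visible CUDA devices; otherwise reuse `device`.
    if device == "cuda" and device_count > 0:
        return f"cuda:{cid % device_count}"
    return device


def worker(cid: int, device: str, payload: ToyDownlink, stop_event: threading.Event) -> ToyUplink:
    # Placeholder for local training on `device`.
    return ToyUplink(cid=cid, device=device, value=payload.global_value + cid)


def local_process(payload: ToyDownlink, cid_list: list[int], device: str = "cuda",
                  device_count: int = 2, num_parallels: int = 4) -> list[ToyUplink]:
    stop_event = threading.Event()
    # At most `num_parallels` clients run at the same time.
    with ThreadPoolExecutor(max_workers=num_parallels) as pool:
        futures = [
            pool.submit(worker, cid, get_client_device(cid, device, device_count), payload, stop_event)
            for cid in cid_list
        ]
        # Collect in submission order so results line up with cid_list.
        return [f.result() for f in futures]


if __name__ == "__main__":
    for pkg in local_process(ToyDownlink(global_value=1.0), [0, 1, 2, 3]):
        print(pkg)
```

Bounding the pool by num_parallels rather than by the number of clients keeps GPU memory use predictable when many clients are sampled per round.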
Attributes
num_parallels
device
device_count
cache
stop_event
- progress_fn(it: list[Future[FedAvgUplinkPackage]]) → Iterable[Future[FedAvgUplinkPackage]]
A no-op progress function that can be overridden to provide custom progress tracking.
- Parameters:
it (list[Future[UplinkPackage]]) – A list of Future objects representing the results of client processing.
- Returns:
The original iterable.
- Return type:
Iterable[Future[UplinkPackage]]
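The default progress_fn returns its input unchanged, so an override can wrap the futures to report progress. The sketch below is standalone and does not import blazefl. In a real subclass it would be a method that also takes self. It assumes the caller only needs every future eventually and does not depend on their order, because as_completed yields futures in completion order. The function names are hypothetical.

```python
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from typing import Iterable, Iterator


def noop_progress_fn(it: list[Future]) -> Iterable[Future]:
    # Mirrors the documented default: hand the futures back unchanged.
    return it


def printing_progress_fn(it: list[Future]) -> Iterator[Future]:
    # Hypothetical override: yield futures as they finish and print a counter.
    total = len(it)
    for done, fut in enumerate(as_completed(it), start=1):
        print(f"[progress] {done}/{total} clients finished")
        yield fut


if __name__ == "__main__":
    with ThreadPoolExecutor(max_workers=3) as pool:
        futures = [pool.submit(pow, 2, n) for n in range(5)]
        print(sorted(f.result() for f in printing_progress_fn(futures)))
```

A progress-bar library could be swapped in for the print call. The generator shape stays the same.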
- train(model: Module, model_parameters: Tensor, train_loader: DataLoader, device: str, epochs: int, lr: float, stop_event: Event, cid: int) → FedAvgUplinkPackage
Train the model with the given training data loader.
- Parameters:
model (torch.nn.Module) – The model to train.
model_parameters (torch.Tensor) – Initial global model parameters.
train_loader (DataLoader) – DataLoader for the training data.
device (str) – Device to run the training on.
epochs (int) – Number of local training epochs.
lr (float) – Learning rate for the optimizer.
stop_event (threading.Event) – Event to signal stopping the training.
cid (int) – The client ID.
- Returns:
Uplink package containing updated model parameters and data size.
- Return type:
FedAvgUplinkPackage
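The contract of train (start from the global parameters, run local epochs, return updated parameters plus the data size) can be illustrated without PyTorch. This is a minimal sketch under several assumptions: a two-parameter linear model y = w * x + b trained with mini-batch SGD on squared error, a stop check once per epoch, and "data size" meaning the number of local samples. ToyUplinkPackage and toy_train are hypothetical names, not blazefl's API.

```python
import threading
from dataclasses import dataclass


@dataclass
class ToyUplinkPackage:
    # Hypothetical stand-in for FedAvgUplinkPackage.
    model_parameters: list[float]
    data_size: int


def toy_train(
    model_parameters: list[float],
    train_data: list[tuple[float, float]],
    epochs: int,
    lr: float,
    stop_event: threading.Event,
    batch_size: int = 2,
) -> ToyUplinkPackage:
    # Start from the global parameters (w, b) of y = w * x + b.
    w, b = model_parameters
    for _ in range(epochs):
        # Assumption: the stop signal is checked once per epoch.
        if stop_event.is_set():
            break
        for start in range(0, len(train_data), batch_size):
            batch = train_data[start : start + batch_size]
            # Gradients of the mean squared error over the mini-batch.
            errors = [w * x + b - y for x, y in batch]
            grad_w = 2 * sum(e * x for e, (x, _) in zip(errors, batch)) / len(batch)
            grad_b = 2 * sum(errors) / len(batch)
            w -= lr * grad_w
            b -= lr * grad_b
    # Assumption: data size is the number of local training samples.
    return ToyUplinkPackage(model_parameters=[w, b], data_size=len(train_data))


if __name__ == "__main__":
    data = [(x / 10, 3 * (x / 10) + 1) for x in range(10)]
    print(toy_train([0.0, 0.0], data, epochs=500, lr=0.1, stop_event=threading.Event()))
```

On noiseless data from y = 3x + 1 the parameters approach (3, 1). An event that is already set returns the starting parameters untouched.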
- uplink_package() → list[FedAvgUplinkPackage]
Prepare the data package to be sent from the client to the server.
- Returns:
A list of data packages prepared for uplink transmission.
- Return type:
list[UplinkPackage]
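The list returned by uplink_package is what the server aggregates. In FedAvg the usual server rule is an average of client parameters weighted by each client's data size, which the uplink package carries according to the description of train. The following minimal sketch of that rule works on plain Python lists with a hypothetical ToyUplinkPackage. It is not blazefl's server-side code.

```python
from dataclasses import dataclass


@dataclass
class ToyUplinkPackage:
    # Hypothetical stand-in for FedAvgUplinkPackage.
    model_parameters: list[float]
    data_size: int


def fedavg_aggregate(packages: list[ToyUplinkPackage]) -> list[float]:
    # Weighted average: each client counts in proportion to its data size.
    total = sum(p.data_size for p in packages)
    if total == 0:
        raise ValueError("cannot aggregate packages with zero total data size")
    aggregated = [0.0] * len(packages[0].model_parameters)
    for p in packages:
        weight = p.data_size / total
        for i, value in enumerate(p.model_parameters):
            aggregated[i] += weight * value
    return aggregated


if __name__ == "__main__":
    uplinks = [ToyUplinkPackage([1.0, 2.0], 1), ToyUplinkPackage([3.0, 6.0], 3)]
    print(fedavg_aggregate(uplinks))  # [2.5, 5.0]
```

In this example the second client holds three times as much data, so its parameters receive weight 0.75.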
- worker(cid: int, device: str, payload: FedAvgDownlinkPackage, stop_event: Event) → FedAvgUplinkPackage
Process a single client’s training task in a thread.
- Parameters:
cid (int) – The client ID.
device (str) – The device to use for processing this client.
payload (DownlinkPackage) – The data package received from the server.
stop_event (threading.Event) – Event to signal stopping the worker.
- Returns:
The uplink package containing the client’s results.
- Return type:
UplinkPackage
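One common way to use a shared threading.Event like stop_event is cooperative cancellation. Each worker checks the event between units of work, and a failure anywhere sets the event so the other threads exit early. The sketch below assumes that usage. It replaces the downlink payload with a step count for brevity and simulates a failure on client 1. All names are hypothetical and it does not import blazefl.

```python
import threading
import time
from concurrent.futures import Future, ThreadPoolExecutor


def worker(cid: int, device: str, steps: int, stop_event: threading.Event) -> int:
    # Hypothetical worker: performs `steps` small units of work and checks
    # stop_event between them, so it can exit early when asked.
    done = 0
    for _ in range(steps):
        if stop_event.is_set():
            break
        if cid == 1 and done == 2:
            # Simulated failure on one client.
            raise ValueError(f"client {cid} failed on {device}")
        time.sleep(0.01)
        done += 1
    return done


def run_clients(cids: list[int], steps: int = 50) -> dict[int, object]:
    stop_event = threading.Event()

    def on_done(fut: Future) -> None:
        # Any failed client asks all other workers to stop.
        if fut.exception() is not None:
            stop_event.set()

    futures: dict[int, Future] = {}
    with ThreadPoolExecutor(max_workers=len(cids)) as pool:
        for cid in cids:
            fut = pool.submit(worker, cid, "cpu", steps, stop_event)
            fut.add_done_callback(on_done)
            futures[cid] = fut
    # Leaving the `with` block waited for every worker to return or raise.
    outcome: dict[int, object] = {}
    for cid, fut in futures.items():
        exc = fut.exception()
        outcome[cid] = repr(exc) if exc is not None else fut.result()
    return outcome


if __name__ == "__main__":
    print(run_clients([0, 1, 2, 3]))
```

The surviving clients report far fewer than 50 completed steps because they noticed the event soon after client 1 failed. A Python thread cannot be killed from outside, so this polling is the only way to stop it early.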