spyrit.core.train.load_net

spyrit.core.train.load_net(path, model, device=None, strict=True)

Loads the network weights stored at path into model. The weights are loaded in place: model is modified directly and nothing is returned.

Args:

path (str): full path to the saved model file. The path must include the file extension.

model (torch.nn.Module): model to load the weights into. The model must have the same architecture as the model that was saved.

device (str, optional): device to load the model on. If None, the model is loaded on the CPU.

strict (bool): passed on to the load_state_dict method of torch.nn.Module. If True, the keys of the saved state_dict must exactly match the keys of the model's state_dict, and any mismatch raises an exception.
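Because strict is forwarded to torch.nn.Module.load_state_dict, its effect can be reproduced with PyTorch alone. The two small networks below are hypothetical: the saved one has an extra layer, so its state_dict contains keys that the target network does not have.

```python
import torch.nn as nn

# Hypothetical networks: "saved" has one more Linear layer than "target"
saved = nn.Sequential(nn.Linear(3, 3), nn.ReLU(), nn.Linear(3, 1))
target = nn.Sequential(nn.Linear(3, 3))

# Keys: "0.weight", "0.bias", "2.weight", "2.bias"
state_dict = saved.state_dict()

# strict=True: the extra keys "2.weight" and "2.bias" trigger a RuntimeError
try:
    target.load_state_dict(state_dict, strict=True)
    raised = False
except RuntimeError:
    raised = True
print(raised)  # True

# strict=False: matching keys are loaded, mismatches are only reported
result = target.load_state_dict(state_dict, strict=False)
print(result.missing_keys)     # []
print(result.unexpected_keys)  # reports "2.weight" and "2.bias"
```

With strict=False, inspecting missing_keys and unexpected_keys on the returned object is the usual way to check that a partial load did what was intended.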

Returns:

None
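For illustration only, here is a minimal, hypothetical sketch of a loader with the same signature, written in plain PyTorch. It assumes the file at path holds a state_dict saved with torch.save; the helper name load_net_sketch and the file name linear.pth are made up for this example, and the code is not spyrit's actual implementation.

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_net_sketch(path, model, device=None, strict=True):
    """Hypothetical stand-in for load_net; not spyrit's source code."""
    # Assumption: with no device given, fall back to the CPU
    if device is None:
        device = "cpu"
    # map_location decides where the saved tensors are placed when read from disk
    state_dict = torch.load(path, map_location=device)
    # load_state_dict copies the tensors into model's existing parameters (in place)
    model.load_state_dict(state_dict, strict=strict)
    return None


# Save the weights of one network, then load them into a second one
src = nn.Linear(4, 2)
dst = nn.Linear(4, 2)  # same architecture, different random initialization
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "linear.pth")
    torch.save(src.state_dict(), path)
    load_net_sketch(path, dst)

print(torch.equal(src.weight, dst.weight))  # True
```

Since the weights are copied into the parameters the model already owns, the caller keeps using the same model object after the call, which is why the function can return None.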