diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index 359932c7..a12d2eb7 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -44,6 +44,10 @@ def set_random_seed(seed: int) -> Union[None, int]: # if torch not installed don't set random seed. if sys.modules["torch"]: th.manual_seed(seed) + + th.backends.cudnn.deterministic = True + th.backends.cudnn.benchmark = False + return seed