diff options
Diffstat (limited to 'modeling/model_env.py')
-rw-r--r-- | modeling/model_env.py | 22 |
1 files changed, 15 insertions, 7 deletions
diff --git a/modeling/model_env.py b/modeling/model_env.py index 433b2da..45b3207 100644 --- a/modeling/model_env.py +++ b/modeling/model_env.py @@ -13,6 +13,9 @@ class Stimuli: training_order_list = [] test_order_list = [] + current_proc_id = 1 + current_phase = "train" + def __init__(self, condition, training_N=75, test_N=50): self.condition = condition self.training_N = training_N @@ -44,15 +47,21 @@ class Stimuli: ) if self.condition == "random": - self.training_order_list = self.order_list[:self.training_N] - self.test_order_list = self.order_list[self.training_N:] + self.training_order_list = self.order_list[: self.training_N] + self.test_order_list = self.order_list[self.training_N :] def next_stimulus(self): self.current_stimulus_id += 1 - if self.current_stimulus_id < self.training_N: - return self.training_stimuli[self.current_stimulus_id] - else: - return self.test_stimuli[self.current_stimulus_id - self.training_N] + if self.condition != "blocked": + if self.current_stimulus_id < self.training_N: + return self.training_stimuli[self.current_stimulus_id] + else: + return self.test_stimuli[self.current_stimulus_id - self.training_N] + elif self.condition == "blocked": + if self.current_stimulus_id > 6 * self.training_N: + return self.test_stimuli[self.current_stimulus_id % 6 - self.training_N] + else: + return self.training_stimuli[self.current_stimulus_id % 6] def update_current_stimulus(self, key, value): if self.current_stimulus_id < self.training_N: @@ -64,7 +73,6 @@ class Stimuli: ] = value return self.test_stimuli[self.current_stimulus_id - self.training_N] - def generate_environments(self, water_samples, order): envs = [] |