summaryrefslogtreecommitdiff
path: root/modeling/model_env.py
diff options
context:
space:
mode:
Diffstat (limited to 'modeling/model_env.py')
-rw-r--r--modeling/model_env.py22
1 files changed, 15 insertions, 7 deletions
diff --git a/modeling/model_env.py b/modeling/model_env.py
index 433b2da..45b3207 100644
--- a/modeling/model_env.py
+++ b/modeling/model_env.py
@@ -13,6 +13,9 @@ class Stimuli:
training_order_list = []
test_order_list = []
+ current_proc_id = 1
+ current_phase = "train"
+
def __init__(self, condition, training_N=75, test_N=50):
self.condition = condition
self.training_N = training_N
@@ -44,15 +47,21 @@ class Stimuli:
)
if self.condition == "random":
- self.training_order_list = self.order_list[:self.training_N]
- self.test_order_list = self.order_list[self.training_N:]
+ self.training_order_list = self.order_list[: self.training_N]
+ self.test_order_list = self.order_list[self.training_N :]
def next_stimulus(self):
self.current_stimulus_id += 1
- if self.current_stimulus_id < self.training_N:
- return self.training_stimuli[self.current_stimulus_id]
- else:
- return self.test_stimuli[self.current_stimulus_id - self.training_N]
+ if self.condition != "blocked":
+ if self.current_stimulus_id < self.training_N:
+ return self.training_stimuli[self.current_stimulus_id]
+ else:
+ return self.test_stimuli[self.current_stimulus_id - self.training_N]
+ elif self.condition == "blocked":
+ if self.current_stimulus_id > 6 * self.training_N:
+ return self.test_stimuli[self.current_stimulus_id % 6 - self.training_N]
+ else:
+ return self.training_stimuli[self.current_stimulus_id % 6]
def update_current_stimulus(self, key, value):
if self.current_stimulus_id < self.training_N:
@@ -64,7 +73,6 @@ class Stimuli:
] = value
return self.test_stimuli[self.current_stimulus_id - self.training_N]
-
def generate_environments(self, water_samples, order):
envs = []