diff --git a/agent/__init__.py b/agent/__init__.py
index e8df281..8016aed 100644
--- a/agent/__init__.py
+++ b/agent/__init__.py
@@ -1 +1 @@
-from .util import create_create_agent
\ No newline at end of file
+from .util import create_agent_helper
\ No newline at end of file
diff --git a/agent/util.py b/agent/util.py
index 4534b20..a301ab7 100644
--- a/agent/util.py
+++ b/agent/util.py
@@ -1,12 +1,12 @@
 from recsim.agents import full_slate_q_agent
 
-def create_create_agent(agent=full_slate_q_agent.FullSlateQAgent):
+def create_agent_helper(agent=full_slate_q_agent.FullSlateQAgent, **kwargs):
     def create_agent(sess, environment, eval_mode, summary_writer=None):
-        kwargs = {
-            'observation_space': environment.observation_space,
-            'action_space': environment.action_space,
-            'summary_writer': summary_writer,
-            'eval_mode': eval_mode,
-        }
+        print(f"using {agent.__name__}")
+        kwargs['observation_space'] = environment.observation_space
+        kwargs['action_space'] = environment.action_space
+        kwargs['summary_writer'] = summary_writer
+        kwargs['eval_mode'] = eval_mode
+
         return agent(sess, **kwargs)
     return create_agent
\ No newline at end of file
diff --git a/main.py b/main.py
index 7650847..7cef75c 100644
--- a/main.py
+++ b/main.py
@@ -5,7 +5,7 @@ from document import FlashcardDocumentSampler
 from recsim.simulator import recsim_gym
 from recsim.agents import full_slate_q_agent
 from recsim.simulator import runner_lib
-from agent import create_create_agent
+from agent import create_agent_helper
 from util import reward, update_metrics
 
 slate_size = 1
@@ -14,7 +14,7 @@ time_budget = 60
 
 tf.compat.v1.disable_eager_execution()
 
-create_agent_fn = create_create_agent(full_slate_q_agent.FullSlateQAgent)
+create_agent_fn = create_agent_helper(full_slate_q_agent.FullSlateQAgent)
 
 ltsenv = environment.Environment(
     FlashcardUserModel(num_candidates, time_budget, slate_size),
@@ -31,7 +31,7 @@ runner = runner_lib.TrainRunner(
     base_dir=tmp_base_dir,
     create_agent_fn=create_agent_fn,
     env=lts_gym_env,
-    episode_log_file="",
+    episode_log_file="episode.log",
     max_training_steps=5,
     num_iterations=1
 )
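
Note for reviewers: a minimal usage sketch of the renamed helper, assuming the intent of the new **kwargs parameter is to forward extra constructor arguments to the agent (the gamma keyword below is a hypothetical illustration, not a verified FullSlateQAgent parameter):

    from recsim.agents import full_slate_q_agent
    from agent import create_agent_helper

    # Extra keyword arguments are captured in the closure and merged with the
    # environment-derived arguments before the agent is constructed.
    # gamma=0.9 is only a hypothetical example of such an argument.
    create_agent_fn = create_agent_helper(full_slate_q_agent.FullSlateQAgent, gamma=0.9)

    # runner_lib.TrainRunner later invokes the returned closure as:
    #   agent = create_agent_fn(sess, environment, eval_mode, summary_writer=summary_writer)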