change create_agent interface
This commit is contained in:
parent
5d87afdcc0
commit
84a08b08ee
@ -1 +1 @@
|
|||||||
from .util import create_create_agent
|
from .util import create_agent_helper
|
@ -1,12 +1,12 @@
|
|||||||
from recsim.agents import full_slate_q_agent
|
from recsim.agents import full_slate_q_agent
|
||||||
|
|
||||||
def create_create_agent(agent=full_slate_q_agent.FullSlateQAgent):
|
def create_agent_helper(agent=full_slate_q_agent.FullSlateQAgent, **kwargs):
|
||||||
def create_agent(sess, environment, eval_mode, summary_writer=None):
|
def create_agent(sess, environment, eval_mode, summary_writer=None):
|
||||||
kwargs = {
|
print(f"using {agent.__name__}")
|
||||||
'observation_space': environment.observation_space,
|
kwargs['observation_space'] = environment.observation_space
|
||||||
'action_space': environment.action_space,
|
kwargs['action_space'] = environment.action_space
|
||||||
'summary_writer': summary_writer,
|
kwargs['summary_writer'] = summary_writer
|
||||||
'eval_mode': eval_mode,
|
kwargs['eval_mode'] = eval_mode
|
||||||
}
|
|
||||||
return agent(sess, **kwargs)
|
return agent(sess, **kwargs)
|
||||||
return create_agent
|
return create_agent
|
6
main.py
6
main.py
@ -5,7 +5,7 @@ from document import FlashcardDocumentSampler
|
|||||||
from recsim.simulator import recsim_gym
|
from recsim.simulator import recsim_gym
|
||||||
from recsim.agents import full_slate_q_agent
|
from recsim.agents import full_slate_q_agent
|
||||||
from recsim.simulator import runner_lib
|
from recsim.simulator import runner_lib
|
||||||
from agent import create_create_agent
|
from agent import create_agent_helper
|
||||||
from util import reward, update_metrics
|
from util import reward, update_metrics
|
||||||
|
|
||||||
slate_size = 1
|
slate_size = 1
|
||||||
@ -14,7 +14,7 @@ time_budget = 60
|
|||||||
|
|
||||||
tf.compat.v1.disable_eager_execution()
|
tf.compat.v1.disable_eager_execution()
|
||||||
|
|
||||||
create_agent_fn = create_create_agent(full_slate_q_agent.FullSlateQAgent)
|
create_agent_fn = create_agent_helper(full_slate_q_agent.FullSlateQAgent)
|
||||||
|
|
||||||
ltsenv = environment.Environment(
|
ltsenv = environment.Environment(
|
||||||
FlashcardUserModel(num_candidates, time_budget, slate_size),
|
FlashcardUserModel(num_candidates, time_budget, slate_size),
|
||||||
@ -31,7 +31,7 @@ runner = runner_lib.TrainRunner(
|
|||||||
base_dir=tmp_base_dir,
|
base_dir=tmp_base_dir,
|
||||||
create_agent_fn=create_agent_fn,
|
create_agent_fn=create_agent_fn,
|
||||||
env=lts_gym_env,
|
env=lts_gym_env,
|
||||||
episode_log_file="",
|
episode_log_file="episode.log",
|
||||||
max_training_steps=5,
|
max_training_steps=5,
|
||||||
num_iterations=1
|
num_iterations=1
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user