import gc
import gzip
import os
import re

import gym
import gym.spaces
import numpy as np
from gym.utils import seeding

import retro
from retro.data import GameData


def _parse_version(version):
    """Return the leading numeric components of a dotted version string.

    Components are read left to right; each contributes its leading digits,
    and parsing stops at the first component with no leading digit. So
    ``'0.9.6'`` -> ``(0, 9, 6)``, ``'0.10.0rc1'`` -> ``(0, 10, 0)`` and
    ``'0.26.2.dev0'`` -> ``(0, 26, 2)`` instead of raising ValueError.
    """
    parts = []
    for piece in version.split('.'):
        match = re.match(r'\d+', piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


# Installed gym version as a tuple; compared against (0, 9, 6) below to pick
# between the older and newer gym APIs.
gym_version = _parse_version(gym.__version__)


class RetroEnv(gym.Env):
    """Gym environment that runs a game ROM through ``retro.RetroEmulator``.

    Observations are the emulator screen (``em.get_screen()``). Reward,
    termination and the ``info`` dict come from a ``GameData`` instance
    loaded from the game's info and scenario JSON files. Episodes can
    optionally be recorded to ``.bk2`` movie files.
    """

    metadata = {'render.modes': ['human', 'rgb_array'], 'video.frames_per_second': 60.0}

    def compute_step(self, image):
        """Return ``(reward, done, info)`` for the current frame from game data.

        ``image`` is not used by this implementation; presumably it is a hook
        for subclasses that derive reward from pixels.
        """
        reward = self.data.current_reward()
        done = self.data.is_done()
        return reward, done, self.data.lookup_all()

    def record_movie(self, path):
        """Start recording button inputs to a new movie file at ``path``.

        A movie that is already being recorded is closed explicitly first,
        rather than being left for the garbage collector.
        """
        if self.movie:
            self.movie.close()
        self.movie = retro.Movie(path, True)
        self.movie.configure(self.gamename, self.em)
        if self.initial_state:
            self.movie.set_state(self.initial_state)

    def stop_record(self):
        """Disable auto-recording, reset the movie counter and close any open movie."""
        self.movie_path = None
        self.movie_id = 0
        if self.movie:
            self.movie.close()
            self.movie = None

    def auto_record(self, path=None):
        """Record a new movie into directory ``path`` on every ``reset()``.

        ``path`` defaults to the current working directory.
        """
        if not path:
            path = os.getcwd()
        self.movie_path = path

    def __init__(self, game, state=None, scenario=None, info=None, use_restricted_actions=retro.ACTIONS_FILTERED, record=False):
        """Create the emulator for ``game`` and build the action/observation spaces.

        Args:
            game: game name, resolved with ``retro.get_game_path`` and
                ``retro.get_romfile_path``.
            state: name of a gzip-compressed save state inside the game
                directory (``.state`` is appended if missing), or None for no
                initial state.
            scenario: scenario name inside the game directory, or a path
                ending in ``.json``. Defaults to ``'scenario'``.
            info: info file name inside the game directory, or a path ending
                in ``.json``. Defaults to ``'data'``.
            use_restricted_actions: ``retro.ACTIONS_DISCRETE`` for a single
                ``Discrete`` space, ``retro.ACTIONS_MULTI_DISCRETE`` for one
                choice per combo group, anything else for ``MultiBinary``
                over all buttons.
            record: True to auto-record into the cwd, a directory path to
                auto-record there, or False to disable recording.

        Raises:
            AssertionError: if the info or scenario file fails to load.
        """
        if not hasattr(self, 'spec'):
            self.spec = None
        self.img = None
        self.viewer = None
        self.gamename = game
        self.statename = state

        game_path = retro.get_game_path(game)
        rom_path = retro.get_romfile_path(game)

        if state is None:
            self.initial_state = None
        else:
            if not state.endswith('.state'):
                state += '.state'
            with gzip.open(os.path.join(game_path, state), 'rb') as fh:
                self.initial_state = fh.read()

        self.data = GameData()

        if info is None:
            info = 'data'
        if info.endswith('.json'):
            # assume it's a path
            info_path = info
        else:
            info_path = os.path.join(game_path, info + '.json')

        if scenario is None:
            scenario = 'scenario'
        if scenario.endswith('.json'):
            # assume it's a path
            scenario_path = scenario
        else:
            scenario_path = os.path.join(game_path, scenario + '.json')

        system = retro.get_romfile_system(rom_path)

        # We can't have more than one emulator per process. Before creating an
        # emulator, ensure that unused ones are garbage-collected.
        gc.collect()
        self.em = retro.RetroEmulator(rom_path)
        self.em.configure_data(self.data)
        self.em.step()
        img = self.em.get_screen()

        core = retro.get_system_info(system)
        self.BUTTONS = core['buttons']
        self.NUM_BUTTONS = len(self.BUTTONS)

        try:
            # Explicit raise rather than an `assert` statement: `python -O`
            # strips asserts, which would skip the load call entirely.
            # AssertionError is kept so callers that catch it still work.
            if not self.data.load(info_path, scenario_path):
                raise AssertionError('Failed to load info (%s) or scenario (%s)' % (info_path, scenario_path))
        except Exception:
            del self.em
            raise

        # Query valid button combinations only after the info/scenario files
        # are loaded; they presumably depend on the loaded scenario.
        self.BUTTON_COMBOS = self.data.valid_actions()

        if use_restricted_actions == retro.ACTIONS_DISCRETE:
            combos = 1
            for combo in self.BUTTON_COMBOS:
                combos *= len(combo)
            self.action_space = gym.spaces.Discrete(combos)
        elif use_restricted_actions == retro.ACTIONS_MULTI_DISCRETE:
            # Older gym MultiDiscrete takes (low, high) pairs instead of sizes.
            self.action_space = gym.spaces.MultiDiscrete([len(combos) if gym_version >= (0, 9, 6) else (0, len(combos) - 1) for combos in self.BUTTON_COMBOS])
        else:
            self.action_space = gym.spaces.MultiBinary(self.NUM_BUTTONS)

        kwargs = {}
        if gym_version >= (0, 9, 6):
            kwargs['dtype'] = np.uint8
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=img.shape, **kwargs)

        self.use_restricted_actions = use_restricted_actions
        self.movie = None
        self.movie_id = 0
        self.movie_path = None
        if record is True:
            self.auto_record()
        elif record is not False:
            self.auto_record(record)
        self.seed()
        if gym_version < (0, 9, 6):
            # Older gym dispatches through underscore-prefixed methods.
            self._seed = self.seed
            self._step = self.step
            self._reset = self.reset
            self._render = self.render
            self._close = self.close

    def step(self, a):
        """Advance the emulator one frame using action ``a``.

        ``a`` is interpreted according to ``use_restricted_actions``:
        a flat index (discrete), one index per combo group (multi-discrete),
        or one 0/1 entry per button (otherwise).

        Returns:
            ``(observation, reward, done, info)`` in the classic gym format.
        """
        action = 0
        if self.use_restricted_actions == retro.ACTIONS_DISCRETE:
            # Decode the flat index as a mixed-radix number: one digit per combo group.
            for combo in self.BUTTON_COMBOS:
                current = a % len(combo)
                a //= len(combo)
                action |= combo[current]
        elif self.use_restricted_actions == retro.ACTIONS_MULTI_DISCRETE:
            for i, choice in enumerate(a):
                action |= self.BUTTON_COMBOS[i][choice]
        else:
            for i, pressed in enumerate(a):
                action |= int(pressed) << i
            if self.use_restricted_actions == retro.ACTIONS_FILTERED:
                action = self.data.filter_action(action)

        # Expand the bitmask into the 16-entry per-button array the emulator takes.
        mask = np.zeros([16], np.uint8)
        for i in range(16):
            mask[i] = (action >> i) & 1
            if self.movie:
                self.movie.set_key(i, mask[i])
        if self.movie:
            self.movie.step()
        self.em.set_button_mask(mask)
        self.em.step()
        self.img = ob = self.em.get_screen()
        self.data.update_ram()
        rew, done, info = self.compute_step(ob)
        return ob, float(rew), bool(done), dict(info)

    def reset(self):
        """Restore the initial state (if any) and return the first observation.

        When auto-recording is enabled, a new movie named
        ``<game>-<state>-<id>.bk2`` is started in ``movie_path``.
        """
        if self.initial_state:
            self.em.set_state(self.initial_state)
        self.em.set_button_mask(np.zeros([16], np.uint8))
        self.em.step()
        if self.movie_path is not None:
            self.record_movie(os.path.join(self.movie_path, '%s-%s-%04d.bk2' % (self.gamename, self.statename, self.movie_id)))
            self.movie_id += 1
        if self.movie:
            self.movie.step()
        self.img = ob = self.em.get_screen()
        self.data.reset()
        self.data.update_ram()
        return ob

    def seed(self, seed=None):
        """Seed ``self.np_random`` and return ``[seed1, seed2]``."""
        self.np_random, seed1 = seeding.np_random(seed)
        # Derive a random seed. This gets passed as a uint, but gets
        # checked as an int elsewhere, so we need to keep it below
        # 2**31.
        seed2 = seeding.hash_seed(seed1 + 1) % 2**31
        return [seed1, seed2]

    def render(self, mode='human', close=False):
        """Render the latest frame.

        ``'rgb_array'`` returns the image. ``'human'`` shows it in a
        ``SimpleImageViewer`` and returns whether the viewer is open.
        Other modes return None. ``close=True`` closes the viewer.
        """
        if close:
            if self.viewer:
                self.viewer.close()
                # Drop the closed viewer so a later 'human' render opens a fresh one.
                self.viewer = None
            return
        if mode == "rgb_array":
            return self.em.get_screen() if self.img is None else self.img
        elif mode == "human":
            if self.viewer is None:
                from gym.envs.classic_control.rendering import SimpleImageViewer
                self.viewer = SimpleImageViewer()
            # Before the first step/reset there is no cached frame; grab one.
            img = self.em.get_screen() if self.img is None else self.img
            self.viewer.imshow(img)
            return self.viewer.isopen

    def close(self):
        """Release the movie, viewer and emulator held by this environment.

        Attribute lookups are guarded so this is safe on a partially
        constructed instance (e.g. when ``__init__`` raised).
        """
        if getattr(self, 'movie', None):
            self.movie.close()
            self.movie = None
        if getattr(self, 'viewer', None):
            self.viewer.close()
            self.viewer = None
        if hasattr(self, 'em'):
            del self.em