Source code for retro.retro_env

import gc
import gzip
import json
import os

import gymnasium as gym
import numpy as np

import retro
import retro.data

__all__ = ["RetroEnv"]


class RetroEnv(gym.Env):
    """
    Gym Retro environment class

    Provides a Gym interface to classic video games
    """

    # "video.frames_per_second" is the legacy gym key; "render_fps" is the key
    # Gymnasium itself looks up. Both are declared so old and new consumers agree.
    metadata = {
        "render_modes": ["human", "rgb_array"],
        "video.frames_per_second": 60.0,
        "render_fps": 60.0,
    }

    def __init__(
        self,
        game,
        state=retro.State.DEFAULT,
        scenario=None,
        info=None,
        use_restricted_actions=retro.Actions.FILTERED,
        record=False,
        players=1,
        inttype=retro.data.Integrations.STABLE,
        obs_type=retro.Observations.IMAGE,
        render_mode="human",
    ):
        """Create the emulator for ``game`` and build the action/observation spaces.

        Args:
            game: Name of the game integration to load.
            state: ``retro.State.NONE`` for no initial state,
                ``retro.State.DEFAULT`` to look the state name up in the game's
                ``metadata.json``, or the name of a state file.
            scenario: Scenario name, or a path ending in ``.json``.
                Defaults to ``"scenario"``.
            info: Info (data) name, or a path ending in ``.json``.
                Defaults to ``"data"``.
            use_restricted_actions: A ``retro.Actions`` value selecting the
                action space (``DISCRETE``, ``MULTI_DISCRETE``; anything else
                yields a ``MultiBinary`` space, filtered when ``FILTERED``).
            record: ``True`` to record movies into the current working
                directory, a path to record there, ``False`` to disable.
            players: Number of players.
            inttype: Integration type used to locate the game's files.
            obs_type: ``retro.Observations.RAM`` or ``retro.Observations.IMAGE``.
            render_mode: ``"human"`` or ``"rgb_array"``.

        Raises:
            AssertionError: If the info or scenario file fails to load.
        """
        if not hasattr(self, "spec"):
            self.spec = None
        self._obs_type = obs_type
        self.img = None
        self.ram = None
        self.viewer = None
        self.gamename = game
        self.statename = state
        self.initial_state = None
        self.players = players

        # Don't return multiple rewards in multiplayer mode by default
        # as stable-baselines3 vectorized environments doesn't support it
        self.multi_rewards = False

        game_metadata = {}
        rom_path = retro.data.get_romfile_path(game, inttype)
        metadata_path = retro.data.get_file_path(game, "metadata.json", inttype)

        if state == retro.State.NONE:
            self.statename = None
        elif state == retro.State.DEFAULT:
            self.statename = None
            # Best effort: a missing or malformed metadata file simply means
            # there is no default state to start from.
            try:
                with open(metadata_path) as f:
                    game_metadata = json.load(f)
                if "default_player_state" in game_metadata and self.players <= len(
                    game_metadata["default_player_state"],
                ):
                    self.statename = game_metadata["default_player_state"][
                        self.players - 1
                    ]
                elif "default_state" in game_metadata:
                    self.statename = game_metadata["default_state"]
                else:
                    self.statename = None
            except (OSError, json.JSONDecodeError):
                pass

        if self.statename:
            self.load_state(self.statename, inttype)

        self.data = retro.data.GameData()

        if info is None:
            info = "data"

        if info.endswith(".json"):
            # assume it's a path
            info_path = info
        else:
            info_path = retro.data.get_file_path(game, info + ".json", inttype)

        if scenario is None:
            scenario = "scenario"

        if scenario.endswith(".json"):
            # assume it's a path
            scenario_path = scenario
        else:
            scenario_path = retro.data.get_file_path(game, scenario + ".json", inttype)

        self.system = retro.get_romfile_system(rom_path)

        # We can't have more than one emulator per process. Before creating an
        # emulator, ensure that unused ones are garbage-collected
        gc.collect()
        self.em = retro.RetroEmulator(rom_path)
        self.em.configure_data(self.data)
        self.em.step()

        core = retro.get_system_info(self.system)
        self.buttons = core["buttons"]
        self.num_buttons = len(self.buttons)

        try:
            # load() is called outside of an ``assert`` statement so that it
            # still runs (and is still checked) under ``python -O``.
            if not self.data.load(info_path, scenario_path):
                raise AssertionError(
                    f"Failed to load info ({info_path}) or scenario ({scenario_path})",
                )
        except Exception:
            # Release the emulator so another one can be created in this process.
            del self.em
            raise

        self.button_combos = self.data.valid_actions()
        if use_restricted_actions == retro.Actions.DISCRETE:
            # One discrete action per combination of choices across all groups,
            # raised to the number of players.
            combos = 1
            for combo in self.button_combos:
                combos *= len(combo)
            self.action_space = gym.spaces.Discrete(combos**players)
        elif use_restricted_actions == retro.Actions.MULTI_DISCRETE:
            self.action_space = gym.spaces.MultiDiscrete(
                [len(combos) for combos in self.button_combos] * players,
            )
        else:
            self.action_space = gym.spaces.MultiBinary(self.num_buttons * players)

        if self._obs_type == retro.Observations.RAM:
            shape = self.get_ram().shape
        else:
            img = [self.get_screen(p) for p in range(players)]
            shape = img[0].shape
        self.observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=shape,
            dtype=np.uint8,
        )

        self.use_restricted_actions = use_restricted_actions
        self.movie = None
        self.movie_id = 0
        self.movie_path = None
        if record is True:
            self.auto_record()
        elif record is not False:
            self.auto_record(record)

        self.render_mode = render_mode

    def _update_obs(self):
        """Refresh and return the cached observation for the configured type.

        Raises:
            ValueError: If the observation type is neither RAM nor IMAGE.
        """
        if self._obs_type == retro.Observations.RAM:
            self.ram = self.get_ram()
            return self.ram
        elif self._obs_type == retro.Observations.IMAGE:
            self.img = self.get_screen()
            return self.img
        else:
            raise ValueError(f"Unrecognized observation type: {self._obs_type}")

    def action_to_array(self, a):
        """Convert an action from the action space into per-player button arrays.

        Args:
            a: An action in the format of ``self.action_space``.

        Returns:
            A list with one ``np.uint8`` array of length ``num_buttons`` per
            player; entry ``i`` is 1 when button ``i`` is pressed.
        """
        actions = []
        for p in range(self.players):
            action = 0
            if self.use_restricted_actions == retro.Actions.DISCRETE:
                # Decode ``a`` as a mixed-radix number, one digit per button
                # group; what is left of ``a`` carries over to the next player.
                for combo in self.button_combos:
                    current = a % len(combo)
                    a //= len(combo)
                    action |= combo[current]
            elif self.use_restricted_actions == retro.Actions.MULTI_DISCRETE:
                # The MultiDiscrete space holds one entry per button *group*
                # for each player, so the per-player stride is the number of
                # groups, not the number of buttons.
                num_groups = len(self.button_combos)
                ap = a[num_groups * p : num_groups * (p + 1)]
                for buttons, choice in zip(self.button_combos, ap):
                    action |= buttons[choice]
            else:
                ap = a[self.num_buttons * p : self.num_buttons * (p + 1)]
                for i, pressed in enumerate(ap):
                    action |= int(pressed) << i
                if self.use_restricted_actions == retro.Actions.FILTERED:
                    action = self.data.filter_action(action)
            # Expand the bitmask back into one 0/1 entry per button.
            ap = np.zeros([self.num_buttons], np.uint8)
            for i in range(self.num_buttons):
                ap[i] = (action >> i) & 1
            actions.append(ap)
        return actions

    def step(self, a):
        """Apply action ``a``, advance the emulator one step and observe.

        Returns:
            ``(observation, reward, terminated, truncated, info)``;
            ``truncated`` is always ``False``.

        Raises:
            RuntimeError: If called before the first ``reset()``.
        """
        if self.img is None and self.ram is None:
            raise RuntimeError("Please call env.reset() before env.step()")

        for p, ap in enumerate(self.action_to_array(a)):
            if self.movie:
                for i in range(self.num_buttons):
                    self.movie.set_key(i, ap[i], p)
            self.em.set_button_mask(ap, p)

        if self.movie:
            self.movie.step()
        self.em.step()
        self.data.update_ram()
        ob = self._update_obs()
        rew, done, info = self.compute_step()

        if self.render_mode == "human":
            self.render()

        return ob, rew, bool(done), False, dict(info)

    def reset(self, seed=None, options=None):
        """Restore the initial state and return the first observation.

        When automatic recording is enabled, a new movie file is started for
        every reset.

        Returns:
            ``(observation, info)`` where ``info`` is an empty dict.
        """
        super().reset(seed=seed)

        if self.initial_state:
            self.em.set_state(self.initial_state)
        # Release every button before stepping so no input leaks across episodes.
        for p in range(self.players):
            self.em.set_button_mask(np.zeros([self.num_buttons], np.uint8), p)
        self.em.step()
        if self.movie_path is not None:
            # statename is None when the environment has no initial state
            # (e.g. retro.State.NONE); fall back to a placeholder instead of
            # crashing in os.path.basename(None).
            if self.statename:
                rel_statename = os.path.splitext(os.path.basename(self.statename))[0]
            else:
                rel_statename = "none"
            self.record_movie(
                os.path.join(
                    self.movie_path,
                    f"{self.gamename}-{rel_statename}-{self.movie_id:06d}.bk2",
                ),
            )
            self.movie_id += 1
        if self.movie:
            self.movie.step()
        self.data.reset()
        self.data.update_ram()

        if self.render_mode == "human":
            self.render()

        return self._update_obs(), {}

    def render(self):
        """Render the current frame according to ``self.render_mode``.

        Returns:
            The frame for ``"rgb_array"``; whether the viewer window is open
            for ``"human"``; ``None`` for any other mode.
        """
        mode = self.render_mode

        img = self.get_screen() if self.img is None else self.img
        if mode == "rgb_array":
            return img
        elif mode == "human":
            if self.viewer is None:
                # Imported lazily so the viewer is only needed for human rendering.
                from retro.rendering import SimpleImageViewer

                self.viewer = SimpleImageViewer()
            self.viewer.imshow(img)
            return self.viewer.isopen

    def close(self):
        """Drop the emulator reference and close the viewer, if any."""
        if hasattr(self, "em"):
            del self.em
        if self.viewer:
            self.viewer.close()

    def get_action_meaning(self, act):
        """Return the names of the buttons pressed by action ``act``.

        Returns:
            A list of button names for a single player, otherwise a list of
            such lists (one per player).
        """
        actions = []
        for p, action in enumerate(self.action_to_array(act)):
            actions.append(
                [self.buttons[i] for i in np.extract(action, np.arange(len(action)))],
            )
        if self.players == 1:
            return actions[0]
        return actions

    def set_value(self, name, val):
        """Set the game data variable ``name`` to ``val``."""
        self.data.set_value(name, val)

    def get_ram(self):
        """Return all memory blocks, ordered by offset, as one ``uint8`` array."""
        blocks = []
        for offset in sorted(self.data.memory.blocks):
            arr = np.frombuffer(self.data.memory.blocks[offset], dtype=np.uint8)
            blocks.append(arr)
        return np.concatenate(blocks)

    def get_screen(self, player=0):
        """Return the emulator frame, cropped to ``player``'s crop region.

        A zero width/height, or a region extending past the frame, is clamped
        to the frame edge.
        """
        img = self.em.get_screen()
        x, y, w, h = self.data.crop_info(player)
        # From here on ``w`` and ``h`` are the exclusive right/bottom bounds.
        if not w or x + w > img.shape[1]:
            w = img.shape[1]
        else:
            w += x
        if not h or y + h > img.shape[0]:
            h = img.shape[0]
        else:
            h += y
        if x == 0 and y == 0 and w == img.shape[1] and h == img.shape[0]:
            return img
        return img[y:h, x:w]

    def load_state(self, statename, inttype=retro.data.Integrations.DEFAULT):
        """Read a gzip-compressed state file and use it as the initial state.

        Args:
            statename: State name; ``.state`` is appended when missing.
            inttype: Integration type used to locate the state file.
        """
        if not statename.endswith(".state"):
            statename += ".state"

        with gzip.open(
            retro.data.get_file_path(self.gamename, statename, inttype),
            "rb",
        ) as fh:
            self.initial_state = fh.read()

        self.statename = statename

    def compute_step(self):
        """Return ``(reward, done, info)`` for the current emulator state.

        ``reward`` is a per-player list only when there are several players
        and ``self.multi_rewards`` is set.
        """
        if self.players > 1 and self.multi_rewards:
            reward = [self.data.current_reward(p) for p in range(self.players)]
        else:
            reward = self.data.current_reward()
        done = self.data.is_done()
        return reward, done, self.data.lookup_all()

    def record_movie(self, path):
        """Start recording a movie to ``path``."""
        self.movie = retro.Movie(path, True, self.players)
        self.movie.configure(self.gamename, self.em)
        if self.initial_state:
            self.movie.set_state(self.initial_state)

    def stop_record(self):
        """Disable automatic recording and close the current movie, if any."""
        self.movie_path = None
        self.movie_id = 0
        if self.movie:
            self.movie.close()
            self.movie = None

    def auto_record(self, path=None):
        """Record a new movie on every reset into ``path`` (default: the cwd)."""
        if not path:
            path = os.getcwd()
        self.movie_path = path