// retro.cpp (9.57 KiB) -- authored by Vicki Pfau
#include <cmath>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include "coreinfo.h"
#include "data.h"
#include "emulator.h"
#include "imageops.h"
#include "memory.h"
#include "script.h"
#include "movie.h"
#include "movie-bk2.h"
#include <map>
#include <unordered_map>
#include <unordered_set>
// Short alias for the pybind11 namespace, used throughout this file.
namespace py = pybind11;
using std::string;
// Lets the wrappers below name Retro types (Emulator, Image, MovieBK2, ...) unqualified.
using namespace Retro;
// Forward declaration: PyRetroEmulator::configureData() is declared with a
// PyGameData& parameter before the PyGameData struct itself is defined.
struct PyGameData;
/// Python-facing wrapper around a single Retro::Emulator instance.
struct PyRetroEmulator {
    Retro::Emulator m_re;
    /// Index passed to setCheat() for the next cheat; incremented by addCheat(),
    /// reset to zero by clearCheats().
    int m_cheats = 0;

    /// Loads the ROM at rom_path and runs the emulator once.
    /// @throws std::runtime_error if Emulator::isLoaded() already reports true,
    ///         or if loadRom() fails.
    PyRetroEmulator(const string& rom_path) {
        if (Emulator::isLoaded()) {
            throw std::runtime_error("Cannot create multiple emulator instances per process");
        }
        if (!m_re.loadRom(rom_path.c_str())) {
            throw std::runtime_error("Could not load ROM");
        }
        m_re.run(); // otherwise you get a segfault when you try to get screen for the first time
    }

    /// Runs the emulator once.
    void step() {
        m_re.run();
    }

    /// Serializes the emulator into a new bytes object of serializeSize() bytes.
    py::bytes getState() {
        const size_t size = m_re.serializeSize();
        // A null source pointer asks for a bytes object of `size` bytes whose
        // contents are then filled in place by serialize().
        py::bytes bytes(nullptr, size);
        m_re.serialize(PyBytes_AsString(bytes.ptr()), size);
        return bytes;
    }

    /// Restores the emulator from bytes previously produced by getState().
    /// @return whatever unserialize() reports.
    bool setState(py::bytes o) {
        return m_re.unserialize(PyBytes_AsString(o.ptr()), PyBytes_Size(o.ptr()));
    }

    /// Copies the current frame into a new (height, width, 3) uint8 array in RGB888 format.
    py::array_t<uint8_t> getScreen() {
        long w = m_re.getImageWidth();
        long h = m_re.getImageHeight();
        py::array_t<uint8_t> arr({ { h, w, 3 } });
        uint8_t* data = arr.mutable_data();
        Image out(Image::Format::RGB888, data, w, h, w);
        Image in;
        const auto depth = m_re.getImageDepth();
        if (depth == 16) {
            in = Image(Image::Format::RGB565, m_re.getImageData(), w, h, m_re.getImagePitch());
        } else if (depth == 32) {
            in = Image(Image::Format::RGBX888, m_re.getImageData(), w, h, m_re.getImagePitch());
        }
        // NOTE(review): for any other depth `in` stays default-constructed; what
        // copyTo() does in that case depends on Image -- confirm.
        in.copyTo(&out);
        return arr;
    }

    /// Returns the emulator's frame rate.
    double getScreenRate() {
        return m_re.getFrameRate();
    }

    /// Copies the current audio buffer into a new (samples, 2) int16 array.
    py::array_t<int16_t> getAudio() {
        const auto samples = m_re.getAudioSamples();
        py::array_t<int16_t> arr(py::array::ShapeContainer{ samples, 2 });
        // Skip the copy entirely when there is nothing to copy, so memcpy is
        // never handed pointers that belong to an empty buffer.
        if (samples > 0) {
            int16_t* data = arr.mutable_data();
            // Each row of the array is two int16_t values.
            memcpy(data, m_re.getAudioData(), samples * 2 * sizeof(int16_t));
        }
        return arr;
    }

    /// Returns the emulator's audio rate.
    double getAudioRate() {
        return m_re.getAudioRate();
    }

    /// Returns the current image size as a (width, height) tuple.
    py::tuple getResolution() {
        return py::make_tuple(m_re.getImageWidth(), m_re.getImageHeight());
    }

    /// Forwards every entry of mask to setKey(player, index, value).
    /// @throws std::runtime_error if mask has more than N_BUTTONS entries or
    ///         player is negative.
    void setButtonMask(py::array_t<uint8_t> mask, int player) {
        if (mask.size() > N_BUTTONS) {
            throw std::runtime_error("mask.size() > N_BUTTONS");
        }
        // A negative player index can never name a valid player; reject it
        // here instead of handing it to the emulator.
        if (player < 0) {
            throw std::runtime_error("player < 0");
        }
        const uint8_t* buttons = mask.data();
        const auto count = mask.size();
        for (int key = 0; key < count; ++key) {
            m_re.setKey(player, key, buttons[key]);
        }
    }

    /// Enables the cheat `code` at the next free cheat index.
    void addCheat(const string& code) {
        m_re.setCheat(m_cheats, true, code.c_str());
        ++m_cheats;
    }

    /// Clears all cheats in the emulator and resets the cheat index.
    void clearCheats() {
        m_re.clearCheats();
        m_cheats = 0;
    }

    /// Passes the GameData held by `data` to the emulator (defined out of line,
    /// after PyGameData is complete).
    void configureData(PyGameData& data);

    /// Forwards the given JSON text to Retro::loadCoreInfo().
    static bool loadCoreInfo(const string& json) {
        return Retro::loadCoreInfo(json);
    }
};
struct PyMemoryView {
Retro::MemoryView<> m_mem;
PyMemoryView(py::array_t<uint8_t>& mem) {
m_mem.open(static_cast<void*>(mem.mutable_data()), mem.size());
m_mem.clone();
}
int64_t extract(size_t address, const string& type, py::list ospec) {
MemoryOverlay overlay{
ospec.is_none() ? '=' : static_cast<string>(py::str(ospec[0]))[0],
ospec.is_none() ? '=' : static_cast<string>(py::str(ospec[1]))[0],
ospec.is_none() ? 1 : static_cast<size_t>(py::int_(ospec[2])),
};
return DataType(type)(m_mem.offset(0), address, overlay);
}
void assign(size_t address, const string& type, int64_t value, py::list ospec) {
MemoryOverlay overlay{
ospec.is_none() ? '=' : static_cast<string>(py::str(ospec[0]))[0],
ospec.is_none() ? '=' : static_cast<string>(py::str(ospec[1]))[0],
ospec.is_none() ? 1 : static_cast<size_t>(py::int_(ospec[2])),
};
DataType{ type }(m_mem.offset(0), address, overlay) = value;
}
py::array_t<uint8_t> data() {
return py::array_t<uint8_t>(m_mem.size(), &m_mem[0]);
}
};
struct PyGameData {
Retro::GameData m_data;
Retro::Scenario m_scen{ m_data };
bool load(py::handle data = py::none(), py::handle scen = py::none()) {
ScriptContext::reset();
bool success = true;
if (!data.is_none()) {
success = success && m_data.load(py::str(data));
}
if (!scen.is_none()) {
success = success && m_scen.load(py::str(scen));
}
return success;
}
bool save(py::handle data = py::none(), py::handle scen = py::none()) {
bool success = true;
if (!data.is_none()) {
success = success && m_data.save(py::str(data));
}
if (!scen.is_none()) {
success = success && m_scen.save(py::str(scen));
}
return success;
}
void reset() {
m_scen.reloadScripts();
}
uint16_t filterAction(uint16_t action) const {
return m_scen.filterAction(action);
}
py::list validActions() const {
py::list outer;
for (const auto& action : m_scen.validActions()) {
py::list inner;
for (const auto& act : action.second) {
inner.append(act);
}
outer.append(inner);
}
return outer;
}
void updateRam() {
m_data.updateRam();
}
py::dict lookupAll() const {
py::dict data;
for (const auto& var : m_data.lookupAll()) {
data[py::str(var.first)] = var.second;
}
return data;
}
py::dict getVariable(py::str name) const {
py::dict obj;
Retro::Variable var = m_data.getVariable(name);
obj["address"] = var.address;
obj["type"] = var.type.type;
return obj;
}
void setVariable(py::str name, py::dict obj) {
Retro::Variable var{ string(py::str(obj["type"])), py::int_(obj["address"]) };
m_data.setVariable(name, var);
}
void removeVariable(py::str name) {
m_data.removeVariable(name);
}
py::dict listVariables() {
const auto& vars = m_data.listVariables();
py::dict vdict;
for (const auto& var : vars) {
const auto& v = var.second;
vdict[py::str(var.first)] = py::dict(py::arg("address") = v.address, py::arg("type") = v.type.type);
}
return vdict;
}
float currentReward() const {
return m_scen.currentReward();
}
bool isDone() const {
return m_scen.isDone();
}
};
// Hands the emulator a pointer to the GameData owned by `data`. Defined out of
// line because it needs the complete PyGameData type.
void PyRetroEmulator::configureData(PyGameData& data) {
    Retro::GameData* gameData = &data.m_data;
    m_re.configureData(gameData);
}
struct PyMovie {
std::unique_ptr<Retro::Movie> m_movie;
bool recording = false;
PyMovie(py::str name, bool record) {
recording = record;
if (record) {
m_movie = std::make_unique<MovieBK2>(name, true);
} else {
m_movie = Movie::load(name);
}
if (!m_movie) {
throw std::runtime_error("Could not load movie");
}
}
void configure(py::str name, const PyRetroEmulator& emu) {
if (recording) {
static_cast<MovieBK2*>(m_movie.get())->setGameName(name);
static_cast<MovieBK2*>(m_movie.get())->loadKeymap(emu.m_re.core());
}
}
py::str getGameName() const {
return m_movie->getGameName();
}
bool step() {
return m_movie->step();
}
void close() {
m_movie->close();
}
bool getKey(int key) {
return m_movie->getKey(key);
}
void setKey(int key, bool set) {
return m_movie->setKey(key, set);
}
py::bytes getState() {
std::vector<uint8_t> data;
m_movie->getState(&data);
return py::bytes(reinterpret_cast<const char*>(data.data()), data.size());
}
void setState(py::bytes data) {
m_movie->setState(reinterpret_cast<uint8_t*>(PyBytes_AsString(data.ptr())), PyBytes_Size(data.ptr()));
}
};
// Module-level wrapper: passes the string form of `hint` to Retro::corePath()
// and returns the result as a Python str.
// NOTE(review): py::str(hint) is the Python str() of the object, so the default
// None hint is forwarded as the text "None" -- confirm that is intended.
py::str corePath(py::handle hint = py::none()) {
    const string hintText = py::str(hint);
    return Retro::corePath(hintText);
}
// Module-level wrapper: passes the string form of `hint` to
// Retro::GameData::dataPath() and returns the result as a Python str.
// NOTE(review): py::str(hint) is the Python str() of the object, so the default
// None hint is forwarded as the text "None" -- confirm that is intended.
py::str dataPath(py::handle hint = py::none()) {
    const string hintText = py::str(hint);
    return Retro::GameData::dataPath(hintText);
}
// Defines the `_retro` extension module: four wrapper classes plus two
// module-level path helpers. Registration order matches the original.
PYBIND11_MODULE(_retro, m) {
    m.doc() = "libretro bindings";

    // Emulator wrapper.
    py::class_<PyRetroEmulator> emulatorClass(m, "RetroEmulator");
    emulatorClass.def(py::init<const string&>());
    emulatorClass.def("step", &PyRetroEmulator::step);
    emulatorClass.def("set_button_mask", &PyRetroEmulator::setButtonMask, py::arg("mask"), py::arg("player") = 0);
    emulatorClass.def("get_state", &PyRetroEmulator::getState);
    emulatorClass.def("set_state", &PyRetroEmulator::setState);
    emulatorClass.def("get_screen", &PyRetroEmulator::getScreen);
    emulatorClass.def("get_screen_rate", &PyRetroEmulator::getScreenRate);
    emulatorClass.def("get_audio", &PyRetroEmulator::getAudio);
    emulatorClass.def("get_audio_rate", &PyRetroEmulator::getAudioRate);
    emulatorClass.def("get_resolution", &PyRetroEmulator::getResolution);
    emulatorClass.def("configure_data", &PyRetroEmulator::configureData);
    emulatorClass.def("add_cheat", &PyRetroEmulator::addCheat);
    emulatorClass.def("clear_cheats", &PyRetroEmulator::clearCheats);
    emulatorClass.def_static("load_core_info", &PyRetroEmulator::loadCoreInfo);

    // Memory view wrapper; "overlay" may be omitted and then defaults to None.
    py::class_<PyMemoryView> memoryClass(m, "Memory");
    memoryClass.def(py::init<py::array_t<uint8_t>&>());
    memoryClass.def("extract", &PyMemoryView::extract, py::arg("address"), py::arg("type"), py::arg("overlay") = py::none());
    memoryClass.def("assign", &PyMemoryView::assign, py::arg("address"), py::arg("type"), py::arg("value"), py::arg("overlay") = py::none());
    memoryClass.def("data", &PyMemoryView::data);

    // Game data / scenario wrapper.
    py::class_<PyGameData> gameDataClass(m, "GameDataGlue");
    gameDataClass.def(py::init<>());
    gameDataClass.def("load", &PyGameData::load, py::arg("data") = py::none(), py::arg("scen") = py::none());
    gameDataClass.def("save", &PyGameData::save, py::arg("data") = py::none(), py::arg("scen") = py::none());
    gameDataClass.def("reset", &PyGameData::reset);
    gameDataClass.def("filter_action", &PyGameData::filterAction);
    gameDataClass.def("valid_actions", &PyGameData::validActions);
    gameDataClass.def("update_ram", &PyGameData::updateRam);
    gameDataClass.def("lookup_all", &PyGameData::lookupAll);
    gameDataClass.def("get_variable", &PyGameData::getVariable);
    gameDataClass.def("set_variable", &PyGameData::setVariable);
    gameDataClass.def("remove_variable", &PyGameData::removeVariable);
    gameDataClass.def("list_variables", &PyGameData::listVariables);
    gameDataClass.def("current_reward", &PyGameData::currentReward);
    gameDataClass.def("is_done", &PyGameData::isDone);

    // Movie wrapper; "record" defaults to false.
    py::class_<PyMovie> movieClass(m, "Movie");
    movieClass.def(py::init<py::str, bool>(), py::arg("path"), py::arg("record") = false);
    movieClass.def("configure", &PyMovie::configure);
    movieClass.def("get_game", &PyMovie::getGameName);
    movieClass.def("step", &PyMovie::step);
    movieClass.def("close", &PyMovie::close);
    movieClass.def("get_key", &PyMovie::getKey);
    movieClass.def("set_key", &PyMovie::setKey);
    movieClass.def("get_state", &PyMovie::getState);
    movieClass.def("set_state", &PyMovie::setState);

    // Module-level helpers; the leading :: selects the wrappers defined at
    // global scope in this file rather than anything from namespace Retro.
    m.def("core_path", &::corePath, py::arg("hint") = py::none());
    m.def("data_path", &::dataPath, py::arg("hint") = py::none());
}