diff --git a/matplotlibcpp.h b/matplotlibcpp.h index 93a72be5..deaee1f4 100644 --- a/matplotlibcpp.h +++ b/matplotlibcpp.h @@ -553,6 +553,100 @@ void plot_surface(const std::vector<::std::vector> &x, Py_DECREF(kwargs); if (res) Py_DECREF(res); } + +template +void plot_surface(long figure, + const std::vector<::std::vector> &x, + const std::vector<::std::vector> &y, + const std::vector<::std::vector> &z, + const std::map &keywords = + std::map()) +{ + detail::_interpreter::get(); + + // We lazily load the modules here the first time this function is called + // because I'm not sure that we can assume "matplotlib installed" implies + // "mpl_toolkits installed" on all platforms, and we don't want to require + // it for people who don't need 3d plots. + static PyObject *mpl_toolkitsmod = nullptr, *axis3dmod = nullptr; + if (!mpl_toolkitsmod) { + detail::_interpreter::get(); + + PyObject* mpl_toolkits = PyString_FromString("mpl_toolkits"); + PyObject* axis3d = PyString_FromString("mpl_toolkits.mplot3d"); + if (!mpl_toolkits || !axis3d) { throw std::runtime_error("couldnt create string"); } + + mpl_toolkitsmod = PyImport_Import(mpl_toolkits); + Py_DECREF(mpl_toolkits); + if (!mpl_toolkitsmod) { throw std::runtime_error("Error loading module mpl_toolkits!"); } + + axis3dmod = PyImport_Import(axis3d); + Py_DECREF(axis3d); + if (!axis3dmod) { throw std::runtime_error("Error loading module mpl_toolkits.mplot3d!"); } + } + + assert(x.size() == y.size()); + assert(y.size() == z.size()); + + // using numpy arrays + PyObject *xarray = detail::get_2darray(x); + PyObject *yarray = detail::get_2darray(y); + PyObject *zarray = detail::get_2darray(z); + + // construct positional args + PyObject *args = PyTuple_New(3); + PyTuple_SetItem(args, 0, xarray); + PyTuple_SetItem(args, 1, yarray); + PyTuple_SetItem(args, 2, zarray); + + // Build up the kw args. + PyObject *kwargs = PyDict_New(); + PyDict_SetItemString(kwargs, "rstride", PyInt_FromLong(1)); + PyDict_SetItemString(kwargs, "cstride", PyInt_FromLong(1)); + + PyObject *python_colormap_coolwarm = PyObject_GetAttrString( + detail::_interpreter::get().s_python_colormap, "coolwarm"); + + PyDict_SetItemString(kwargs, "cmap", python_colormap_coolwarm); + + for (std::map::const_iterator it = keywords.begin(); + it != keywords.end(); ++it) { + PyDict_SetItemString(kwargs, it->first.c_str(), + PyString_FromString(it->second.c_str())); + } + + PyObject *fig_args = PyTuple_New(1); + PyTuple_SetItem(fig_args, 0, PyLong_FromLong(figure)); + PyObject *fig = PyObject_CallObject(detail::_interpreter::get().s_python_function_figure, fig_args); + if(!fig) throw std::runtime_error("Access to figure failed."); + + PyObject *gca_kwargs = PyDict_New(); + PyDict_SetItemString(gca_kwargs, "projection", PyString_FromString("3d")); + + PyObject *gca = PyObject_GetAttrString(fig, "gca"); + if (!gca) throw std::runtime_error("No gca"); + Py_INCREF(gca); + PyObject *axis = PyObject_Call( + gca, detail::_interpreter::get().s_python_empty_tuple, gca_kwargs); + + if (!axis) throw std::runtime_error("No axis"); + Py_INCREF(axis); + + Py_DECREF(gca); + Py_DECREF(gca_kwargs); + + PyObject *plot_surface = PyObject_GetAttrString(axis, "plot_surface"); + if (!plot_surface) throw std::runtime_error("No surface"); + Py_INCREF(plot_surface); + PyObject *res = PyObject_Call(plot_surface, args, kwargs); + if (!res) throw std::runtime_error("failed surface"); + Py_DECREF(plot_surface); + + Py_DECREF(axis); + Py_DECREF(args); + Py_DECREF(kwargs); + if (res) Py_DECREF(res); +} #endif // WITHOUT_NUMPY template @@ -564,9 +658,9 @@ void plot3(const std::vector &x, { detail::_interpreter::get(); - // Same as with plot_surface: We lazily load the modules here the first time - // this function is called because I'm not sure that we can assume "matplotlib - // installed" implies "mpl_toolkits installed" on all platforms, and we don't + // Same as with plot_surface: We lazily load the modules here the first time + // this function is called because I'm not sure that we can assume "matplotlib + // installed" implies "mpl_toolkits installed" on all platforms, and we don't // want to require it for people who don't need 3d plots. static PyObject *mpl_toolkitsmod = nullptr, *axis3dmod = nullptr; if (!mpl_toolkitsmod) {