Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .appveyor.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,4 @@ build: false
test_script:
- mkdir empty_folder
- cd empty_folder
# eventually `- python -m pytest ../tests`
- python -c "import httpstan"
- python -m pytest ../tests/test_bernoulli.py
61 changes: 42 additions & 19 deletions httpstan/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import os
import pathlib
import platform
import shutil
import string
import sys
import tempfile
Expand All @@ -31,6 +32,18 @@
logger = logging.getLogger("httpstan")


class TemporaryDirectory(tempfile.TemporaryDirectory):
"""Patch TemporaryDirectory to ignore errors."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def cleanup(self):
"""Ignore errors"""
if self._finalizer.detach():
shutil.rmtree(self.name, ignore_errors=True)


def calculate_model_name(program_code: str) -> str:
"""Calculate model name from Stan program code.

Expand Down Expand Up @@ -152,7 +165,7 @@ def import_model_extension_module(model_name: str, module_bytes: bytes):
module_filename = f"{module_name}.so"
assert isinstance(module_bytes, bytes)

with tempfile.TemporaryDirectory() as temporary_directory:
with TemporaryDirectory() as temporary_directory:
with open(os.path.join(temporary_directory, module_filename), "wb") as fh:
fh.write(module_bytes)
module_path = temporary_directory
Expand Down Expand Up @@ -193,21 +206,20 @@ def _build_extension_module(
"""
# write files need for compilation in a temporary directory which will be
# removed when this function exits.
with tempfile.TemporaryDirectory() as temporary_dir:
temporary_dir_path = pathlib.Path(temporary_dir)
cpp_filepath = temporary_dir_path / f"{module_name}.hpp"
pyx_filepath = temporary_dir_path / f"{module_name}.pyx"
with TemporaryDirectory() as temporary_dir:
temporary_dir = pathlib.Path(temporary_dir)
cpp_filepath = temporary_dir / f"{module_name}.hpp"
pyx_filepath = temporary_dir / f"{module_name}.pyx"
pyx_code = string.Template(pyx_code_template).substitute(
cpp_filename=cpp_filepath.as_posix()
)
for filepath, code in zip([cpp_filepath, pyx_filepath], [cpp_code, pyx_code]):
with open(filepath, "w") as fh:
fh.write(code)

httpstan_dir = os.path.dirname(__file__)
include_dirs = [
httpstan_dir, # for queue_writer.hpp and queue_logger.hpp
temporary_dir_path.as_posix(),
temporary_dir.as_posix(),
os.path.join(httpstan_dir, "lib", "stan", "src"),
os.path.join(httpstan_dir, "lib", "stan", "lib", "stan_math"),
os.path.join(httpstan_dir, "lib", "stan", "lib", "stan_math", "lib", "eigen_3.3.3"),
Expand All @@ -230,6 +242,7 @@ def _build_extension_module(
if platform.system() == "Windows":
# -D_hypot=hypot, MinGW fix, https://github.com/python/cpython/pull/11283
extra_compile_args.append("-D_hypot=hypot")
extra_compile_args.append("-static-libgcc")

cython_include_path = [os.path.dirname(httpstan_dir)]
extension = setuptools.Extension(
Expand Down Expand Up @@ -261,12 +274,18 @@ def _redirect_stdout() -> int:
orig_stderr: copy of original stderr file descriptor
"""
sys.stdout.flush()
stdout_fileno = sys.stdout.fileno()
orig_stdout = os.dup(stdout_fileno)
devnull = os.open(os.devnull, os.O_WRONLY)
os.dup2(devnull, stdout_fileno)
os.close(devnull)
return orig_stdout
if platform.system() == "Windows":
orig_stdout = sys.stdout
devnull = open(os.devnull, "w")
sys.stdout = devnull
return orig_stdout, devnull
else:
stdout_fileno = sys.stdout.fileno()
orig_stdout = os.dup(stdout_fileno)
devnull = os.open(os.devnull, os.O_WRONLY)
os.dup2(devnull, stdout_fileno)
os.close(devnull)
return orig_stdout

def _redirect_stderr_to(stream: IO[Any]) -> int:
"""Redirect stderr for subprocesses to /dev/null.
Expand All @@ -289,12 +308,11 @@ def _redirect_stderr_to(stream: IO[Any]) -> int:
if redirect_stderr:
orig_stdout = _redirect_stdout()
orig_stderr = _redirect_stderr_to(stream)

try:
build_extension.extensions = Cython.Build.cythonize(
[extension], include_path=cython_include_path
)
build_extension.build_temp = build_extension.build_lib = temporary_dir_path.as_posix()
build_extension.build_temp = build_extension.build_lib = temporary_dir.as_posix()
build_extension.run()
finally:
if redirect_stderr:
Expand All @@ -303,9 +321,14 @@ def _redirect_stderr_to(stream: IO[Any]) -> int:
stream.close()
# restore
os.dup2(orig_stderr, sys.stderr.fileno())
os.dup2(orig_stdout, sys.stdout.fileno())

if platform.system() == "Windows":
orig_stdout, devnull = orig_stdout
sys.stdout = orig_stdout
devnull.close()
else:
os.dup2(orig_stderr, sys.stderr.fileno())
module = _import_module(module_name, build_extension.build_lib)
assert module.__name__ == module_name, (module.__name__, module_name)
with open(module.__file__, "rb") as fh: # type: ignore # see mypy#3062
assert module.__name__ == module_name, (module.__name__, module_name)
return fh.read(), compiler_output # type: ignore # see mypy#3062
module_bytes = fh.read() # type: ignore # see mypy#3062
return module_bytes, compiler_output