diff --git a/jit/codon/decorator.py b/jit/codon/decorator.py index f7a33ecd9..0cbce448e 100644 --- a/jit/codon/decorator.py +++ b/jit/codon/decorator.py @@ -146,6 +146,59 @@ def _codon_type(arg, **kwargs): def _codon_types(args, **kwargs): return tuple(_codon_type(arg, **kwargs) for arg in args) +def bind_args(func, args, kwargs, drop_self=False): + bound_self = None + if inspect.ismethod(func) and func.__self__ is not None: + bound_self = func.__self__ + func = func.__func__ + args = (bound_self, *args) + sig = inspect.signature(func) + bound_args = sig.bind(*args, **kwargs) + bound_args.apply_defaults() + names = list(sig.parameters) + if drop_self and names and names[0] == "self": + names = names[1:] + return tuple(bound_args.arguments[p] for p in names) + +class JITCallable: + def __init__(self, py_func, obj_name, module, debug=0, sample_size=5, pyvars=None): + self.py_func = py_func + self.obj_name = obj_name + self.module = module + self.debug = debug + self.sample_size = sample_size + self.pyvars = pyvars or [] + def codon_types(self, args): + return _codon_types(args, debug=self.debug, sample_size=self.sample_size) + def bind_args(self, args, kwargs, drop_self=False): + if self.py_func is not None: + return bind_args(self.py_func, args, kwargs, drop_self) + return (*args, *kwargs.values()) + def reset_on_jit_error(self, fn): + try: + return fn() + except JITError: + _reset_jit() + raise + def __call__(self, *args, **kwargs): + def run(): + bound_args = self.bind_args(args, kwargs) + types = self.codon_types(bound_args) + if self.debug > 0: + print( + "[python] {}({})".format(self.obj_name, list(types)), + file=sys.stderr, + ) + return _jit.run_wrapper( + self.obj_name, + list(types), + self.module, + list(self.pyvars), + bound_args, + int(self.debug > 0), + ) + return self.reset_on_jit_error(run) + def _reset_jit(): global _jit _jit = JITWrapper() @@ -258,31 +311,16 @@ def _jit_callback_fn(fn, pyvars=None, *args, **kwargs): - if fn is not None: - sig = inspect.signature(fn) - bound_args = sig.bind(*args, **kwargs) - bound_args.apply_defaults() - args = tuple(bound_args.arguments[param] for param in sig.parameters) - else: - args = (*args, *kwargs.values()) - - try: - types = _codon_types(args, debug=debug, sample_size=sample_size) - if debug > 0: - print("[python] {}({})".format(obj_name, list(types)), file=sys.stderr) - return _jit.run_wrapper( - obj_name, list(types), module, list(pyvars), args, int(debug > 0) - ) - except JITError: - _reset_jit() - raise + return JITCallable(fn, obj_name, module, debug, sample_size, pyvars)( + *args, **kwargs + ) def _jit_str_fn(fstr, debug=0, sample_size=5, pyvars=None): obj_name = _jit_register_fn(fstr, pyvars, debug) + jit_func = JITCallable(None, obj_name, "__main__", debug, sample_size, pyvars) def wrapped(*args, **kwargs): - return _jit_callback_fn(None, obj_name, "__main__", debug, sample_size, - pyvars, *args, **kwargs) + return jit_func(*args, **kwargs) return wrapped @@ -304,11 +342,11 @@ def jit(fn=None, debug=0, sample_size=5, pyvars=None): def _decorate(f): obj_name = _jit_register_fn(f, pyvars, debug) + jit_func = JITCallable(f, obj_name, f.__module__, debug, sample_size, pyvars) @functools.wraps(f) def wrapped(*args, **kwargs): - return _jit_callback_fn(f, obj_name, f.__module__, debug, - sample_size, pyvars, *args, **kwargs) + return jit_func(*args, **kwargs) return wrapped