Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 60 additions & 22 deletions jit/codon/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,59 @@ def _codon_type(arg, **kwargs):
def _codon_types(args, **kwargs):
return tuple(_codon_type(arg, **kwargs) for arg in args)

def bind_args(func, args, kwargs, drop_self=False):
bound_self = None
if inspect.ismethod(func) and func.__self__ is not None:
bound_self = func.__self__
func = func.__func__
args = (bound_self, *args)
sig = inspect.signature(func)
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()
names = list(sig.parameters)
if drop_self and names and names[0] == "self":
names = names[1:]
return tuple(bound_args.arguments[p] for p in names)

class JITCallable:
def __init__(self, py_func, obj_name, module, debug=0, sample_size=5, pyvars=None):
self.py_func = py_func
self.obj_name = obj_name
self.module = module
self.debug = debug
self.sample_size = sample_size
self.pyvars = pyvars or []
def codon_types(self, args):
return _codon_types(args, debug=self.debug, sample_size=self.sample_size)
def bind_args(self, args, kwargs, drop_self=False):
if self.py_func is not None:
return bind_args(self.py_func, args, kwargs, drop_self)
return (*args, *kwargs.values())
def reset_on_jit_error(self, fn):
try:
return fn()
except JITError:
_reset_jit()
raise
def __call__(self, *args, **kwargs):
def run():
bound_args = self.bind_args(args, kwargs)
types = self.codon_types(bound_args)
if self.debug > 0:
print(
"[python] {}({})".format(self.obj_name, list(types)),
file=sys.stderr,
)
return _jit.run_wrapper(
self.obj_name,
list(types),
self.module,
list(self.pyvars),
bound_args,
int(self.debug > 0),
)
return self.reset_on_jit_error(run)

def _reset_jit():
global _jit
_jit = JITWrapper()
Expand Down Expand Up @@ -258,31 +311,16 @@ def _jit_callback_fn(fn,
pyvars=None,
*args,
**kwargs):
if fn is not None:
sig = inspect.signature(fn)
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()
args = tuple(bound_args.arguments[param] for param in sig.parameters)
else:
args = (*args, *kwargs.values())

try:
types = _codon_types(args, debug=debug, sample_size=sample_size)
if debug > 0:
print("[python] {}({})".format(obj_name, list(types)), file=sys.stderr)
return _jit.run_wrapper(
obj_name, list(types), module, list(pyvars), args, int(debug > 0)
)
except JITError:
_reset_jit()
raise
return JITCallable(fn, obj_name, module, debug, sample_size, pyvars)(
*args, **kwargs
)

def _jit_str_fn(fstr, debug=0, sample_size=5, pyvars=None):
obj_name = _jit_register_fn(fstr, pyvars, debug)
jit_func = JITCallable(None, obj_name, "__main__", debug, sample_size, pyvars)

def wrapped(*args, **kwargs):
return _jit_callback_fn(None, obj_name, "__main__", debug, sample_size,
pyvars, *args, **kwargs)
return jit_func(*args, **kwargs)

return wrapped

Expand All @@ -304,11 +342,11 @@ def jit(fn=None, debug=0, sample_size=5, pyvars=None):

def _decorate(f):
obj_name = _jit_register_fn(f, pyvars, debug)
jit_func = JITCallable(f, obj_name, f.__module__, debug, sample_size, pyvars)

@functools.wraps(f)
def wrapped(*args, **kwargs):
return _jit_callback_fn(f, obj_name, f.__module__, debug,
sample_size, pyvars, *args, **kwargs)
return jit_func(*args, **kwargs)

return wrapped

Expand Down