Skip to content

Commit

Permalink
Merge pull request RustPython#4654 from youknowone/pycallble
Browse files Browse the repository at this point in the history
Add protocol object PyCallable
  • Loading branch information
youknowone authored Mar 7, 2023
2 parents bdce56d + b60271a commit 5a74f08
Show file tree
Hide file tree
Showing 59 changed files with 331 additions and 296 deletions.
10 changes: 4 additions & 6 deletions examples/call_between_rust_and_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,12 @@ pub fn main() {

let module = vm.import("call_between_rust_and_python", None, 0).unwrap();
let init_fn = module.get_attr("python_callback", vm).unwrap();
vm.invoke(&init_fn, ()).unwrap();
init_fn.call((), vm).unwrap();

let take_string_fn = module.get_attr("take_string", vm).unwrap();
vm.invoke(
&take_string_fn,
(String::from("Rust string sent to python"),),
)
.unwrap();
take_string_fn
.call((String::from("Rust string sent to python"),), vm)
.unwrap();
})
}

Expand Down
2 changes: 1 addition & 1 deletion examples/package_embed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ fn py_main(interp: &Interpreter) -> vm::PyResult<PyStrRef> {
.expect("add path");
let module = vm.import("package_embed", None, 0)?;
let name_func = module.get_attr("context", vm)?;
let result = vm.invoke(&name_func, ())?;
let result = name_func.call((), vm)?;
let result: PyStrRef = result.get_attr("name", vm)?.try_into_value(vm)?;
vm::PyResult::Ok(result)
})
Expand Down
13 changes: 6 additions & 7 deletions stdlib/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,12 @@ pub(crate) fn make_module(vm: &VirtualMachine) -> PyObjectRef {
.get_attr("MutableSequence", vm)
.expect("Expect collections.abc has MutableSequence type.");

vm.invoke(
&mutable_sequence
.get_attr("register", vm)
.expect("Expect collections.abc.MutableSequence has register method."),
(array,),
)
.expect("Expect collections.abc.MutableSequence.register(array.array) not fail.");
let register = &mutable_sequence
.get_attr("register", vm)
.expect("Expect collections.abc.MutableSequence has register method.");
register
.call((array,), vm)
.expect("Expect collections.abc.MutableSequence.register(array.array) not fail.");

module
}
Expand Down
8 changes: 4 additions & 4 deletions stdlib/src/bisect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ mod _bisect {
let mid = (lo + hi) / 2;
let a_mid = a.get_item(&mid, vm)?;
let comp = if let Some(ref key) = key {
vm.invoke(key, (a_mid,))?
key.call((a_mid,), vm)?
} else {
a_mid
};
Expand All @@ -96,7 +96,7 @@ mod _bisect {
let mid = (lo + hi) / 2;
let a_mid = a.get_item(&mid, vm)?;
let comp = if let Some(ref key) = key {
vm.invoke(key, (a_mid,))?
key.call((a_mid,), vm)?
} else {
a_mid
};
Expand All @@ -112,7 +112,7 @@ mod _bisect {
#[pyfunction]
fn insort_left(BisectArgs { a, x, lo, hi, key }: BisectArgs, vm: &VirtualMachine) -> PyResult {
let x = if let Some(ref key) = key {
vm.invoke(key, (x,))?
key.call((x,), vm)?
} else {
x
};
Expand All @@ -132,7 +132,7 @@ mod _bisect {
#[pyfunction]
fn insort_right(BisectArgs { a, x, lo, hi, key }: BisectArgs, vm: &VirtualMachine) -> PyResult {
let x = if let Some(ref key) = key {
vm.invoke(key, (x,))?
key.call((x,), vm)?
} else {
x
};
Expand Down
4 changes: 2 additions & 2 deletions stdlib/src/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ mod _csv {
) -> PyResult<Writer> {
let write = match vm.get_attribute_opt(file.clone(), "write")? {
Some(write_meth) => write_meth,
None if vm.is_callable(&file) => file,
None if file.is_callable() => file,
None => {
return Err(vm.new_type_error("argument 1 must have a \"write\" method".to_owned()))
}
Expand Down Expand Up @@ -309,7 +309,7 @@ mod _csv {
let s = std::str::from_utf8(&buffer[..buffer_offset])
.map_err(|_| vm.new_unicode_decode_error("csv not utf8".to_owned()))?;

vm.invoke(&self.write, (s.to_owned(),))
self.write.call((s,), vm)
}

#[pymethod]
Expand Down
37 changes: 16 additions & 21 deletions stdlib/src/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ mod machinery;
mod _json {
use super::machinery;
use crate::vm::{
builtins::{PyBaseExceptionRef, PyStrRef, PyTypeRef},
builtins::{PyBaseExceptionRef, PyStrRef, PyType, PyTypeRef},
convert::{ToPyObject, ToPyResult},
function::OptionalArg,
function::{IntoFuncArgs, OptionalArg},
protocol::PyIterReturn,
types::{Callable, Constructor},
AsObject, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine,
Expand Down Expand Up @@ -91,25 +91,23 @@ mod _json {
'{' => {
// TODO: parse the object in rust
let parse_obj = self.ctx.get_attr("parse_object", vm)?;
return PyIterReturn::from_pyresult(
vm.invoke(
&parse_obj,
(
(pystr, next_idx),
self.strict,
scan_once,
self.object_hook.clone(),
self.object_pairs_hook.clone(),
),
let result = parse_obj.call(
(
(pystr, next_idx),
self.strict,
scan_once,
self.object_hook.clone(),
self.object_pairs_hook.clone(),
),
vm,
);
return PyIterReturn::from_pyresult(result, vm);
}
'[' => {
// TODO: parse the array in rust
let parse_array = self.ctx.get_attr("parse_array", vm)?;
return PyIterReturn::from_pyresult(
vm.invoke(&parse_array, ((pystr, next_idx), scan_once)),
parse_array.call(((pystr, next_idx), scan_once), vm),
vm,
);
}
Expand Down Expand Up @@ -138,11 +136,8 @@ mod _json {
($s:literal) => {
if s.starts_with($s) {
return Ok(PyIterReturn::Return(
vm.new_tuple((
vm.invoke(&self.parse_constant, ($s.to_owned(),))?,
idx + $s.len(),
))
.into(),
vm.new_tuple((self.parse_constant.call(($s,), vm)?, idx + $s.len()))
.into(),
));
}
};
Expand Down Expand Up @@ -181,12 +176,12 @@ mod _json {
let ret = if has_decimal || has_exponent {
// float
if let Some(ref parse_float) = self.parse_float {
vm.invoke(parse_float, (buf.to_owned(),))
parse_float.call((buf,), vm)
} else {
Ok(vm.ctx.new_float(f64::from_str(buf).unwrap()).into())
}
} else if let Some(ref parse_int) = self.parse_int {
vm.invoke(parse_int, (buf.to_owned(),))
parse_int.call((buf,), vm)
} else {
Ok(vm.new_pyobj(BigInt::from_str(buf).unwrap()))
};
Expand Down Expand Up @@ -243,7 +238,7 @@ mod _json {
) -> PyBaseExceptionRef {
let get_error = || -> PyResult<_> {
let cls = vm.try_class("json", "JSONDecodeError")?;
let exc = vm.invoke(&cls, (e.msg, s, e.pos))?;
let exc = PyType::call(&cls, (e.msg, s, e.pos).into_args(vm), vm)?;
exc.try_into_value(vm)
};
match get_error() {
Expand Down
2 changes: 1 addition & 1 deletion stdlib/src/math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ mod math {
func_name.as_str(),
)
})?;
vm.invoke(&method, ())
method.call((), vm)
}

#[pyfunction]
Expand Down
2 changes: 1 addition & 1 deletion stdlib/src/pyexpat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ mod _pyexpat {
where
T: IntoFuncArgs,
{
vm.invoke(&handler.read().clone(), args).ok();
handler.read().call(args, vm).ok();
}

#[pyclass]
Expand Down
2 changes: 1 addition & 1 deletion stdlib/src/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ impl TryFromObject for Selectable {
vm.ctx.interned_str("fileno").unwrap(),
|| "select arg must be an int or object with a fileno() method".to_owned(),
)?;
vm.invoke(&meth, ())?.try_into_value(vm)
meth.call((), vm)?.try_into_value(vm)
})?;
Ok(Selectable { obj, fno })
}
Expand Down
28 changes: 15 additions & 13 deletions stdlib/src/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ mod _sqlite {
.map(|val| value_to_object(val, db, vm))
.collect::<PyResult<Vec<PyObjectRef>>>()?;

let val = vm.invoke(func, args)?;
let val = func.call(args, vm)?;

context.result_from_object(&val, vm)
};
Expand All @@ -410,7 +410,7 @@ mod _sqlite {
let args = std::slice::from_raw_parts(argv, argc as usize);
let instance = context.aggregate_context::<*const PyObject>();
if (*instance).is_null() {
match vm.invoke(cls, ()) {
match cls.call((), vm) {
Ok(obj) => *instance = obj.into_raw(),
Err(exc) => {
return context.result_exception(
Expand Down Expand Up @@ -450,7 +450,7 @@ mod _sqlite {
let text2 = ptr_to_string(b_ptr.cast(), b_len, null_mut(), vm)?;
let text2 = vm.ctx.new_str(text2);

let val = vm.invoke(callable, (text1, text2))?;
let val = callable.call((text1, text2), vm)?;
let Some(val) = val.to_number().index(vm) else {
return Ok(0);
};
Expand Down Expand Up @@ -505,7 +505,7 @@ mod _sqlite {
let db_name = ptr_to_str(db_name, vm)?;
let access = ptr_to_str(access, vm)?;

let val = vm.invoke(callable, (action, arg1, arg2, db_name, access))?;
let val = callable.call((action, arg1, arg2, db_name, access), vm)?;
let Some(val) = val.payload::<PyInt>() else {
return Ok(SQLITE_DENY);
};
Expand All @@ -525,15 +525,16 @@ mod _sqlite {
let expanded = sqlite3_expanded_sql(stmt.cast());
let f = || -> PyResult<()> {
let stmt = ptr_to_str(expanded, vm).or_else(|_| ptr_to_str(sql.cast(), vm))?;
vm.invoke(callable, (stmt,)).map(drop)
callable.call((stmt,), vm)?;
Ok(())
};
let _ = f();
0
}

unsafe extern "C" fn progress_callback(data: *mut c_void) -> c_int {
let (callable, vm) = (*data.cast::<Self>()).retrive();
if let Ok(val) = vm.invoke(callable, ()) {
if let Ok(val) = callable.call((), vm) {
if let Ok(val) = val.is_true(vm) {
return val as c_int;
}
Expand Down Expand Up @@ -661,10 +662,10 @@ mod _sqlite {
.new_tuple(vec![obj.class().to_owned().into(), proto.clone()]);

if let Some(adapter) = adapters().get_item_opt(key.as_object(), vm)? {
return vm.invoke(&adapter, (obj,));
return adapter.call((obj,), vm);
}
if let Ok(adapter) = proto.get_attr("__adapt__", vm) {
match vm.invoke(&adapter, (obj,)) {
match adapter.call((obj,), vm) {
Ok(val) => return Ok(val),
Err(exc) => {
if !exc.fast_isinstance(vm.ctx.exceptions.type_error) {
Expand All @@ -674,7 +675,7 @@ mod _sqlite {
}
}
if let Ok(adapter) = obj.get_attr("__conform__", vm) {
match vm.invoke(&adapter, (proto,)) {
match adapter.call((proto,), vm) {
Ok(val) => return Ok(val),
Err(exc) => {
if !exc.fast_isinstance(vm.ctx.exceptions.type_error) {
Expand Down Expand Up @@ -1228,7 +1229,7 @@ mod _sqlite {
fn iterdump(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
let module = vm.import("sqlite3.dump", None, 0)?;
let func = module.get_attr("_iterdump", vm)?;
vm.invoke(&func, (zelf,))
func.call((zelf,), vm)
}

#[pymethod]
Expand Down Expand Up @@ -1699,7 +1700,7 @@ mod _sqlite {
std::slice::from_raw_parts(blob.cast::<u8>(), nbytes as usize)
};
let blob = vm.ctx.new_bytes(blob.to_vec());
vm.invoke(&converter, (blob,))?
converter.call((blob,), vm)?
}
} else {
let col_type = st.column_type(i);
Expand All @@ -1724,7 +1725,7 @@ mod _sqlite {
PyByteArray::from(text).into_ref(vm).into()
} else {
let bytes = vm.ctx.new_bytes(text);
vm.invoke(&text_factory, (bytes,))?
text_factory.call((bytes,), vm)?
}
}
SQLITE_BLOB => {
Expand Down Expand Up @@ -1765,7 +1766,8 @@ mod _sqlite {
let row = vm.ctx.new_tuple(row);

if let Some(row_factory) = zelf.row_factory.to_owned() {
vm.invoke(&row_factory, (zelf.to_owned(), row))
row_factory
.call((zelf.to_owned(), row), vm)
.map(PyIterReturn::Return)
} else {
Ok(PyIterReturn::Return(row.into()))
Expand Down
4 changes: 2 additions & 2 deletions vm/src/builtins/bool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ impl PyObjectRef {
Some(method_or_err) => {
// If descriptor returns Error, propagate it further
let method = method_or_err?;
let bool_obj = vm.invoke(&method, ())?;
let bool_obj = method.call((), vm)?;
if !bool_obj.fast_isinstance(vm.ctx.types.bool_type) {
return Err(vm.new_type_error(format!(
"__bool__ should return bool, returned type {}",
Expand All @@ -50,7 +50,7 @@ impl PyObjectRef {
None => match vm.get_method(self, identifier!(vm, __len__)) {
Some(method_or_err) => {
let method = method_or_err?;
let bool_obj = vm.invoke(&method, ())?;
let bool_obj = method.call((), vm)?;
let int_obj = bool_obj.payload::<PyInt>().ok_or_else(|| {
vm.new_type_error(format!(
"'{}' object cannot be interpreted as an integer",
Expand Down
2 changes: 1 addition & 1 deletion vm/src/builtins/classmethod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ impl GetDescriptor for PyClassMethod {
let call_descr_get: PyResult<PyObjectRef> = zelf.callable.lock().get_attr("__get__", vm);
match call_descr_get {
Err(_) => Ok(PyBoundMethod::new_ref(cls, zelf.callable.lock().clone(), &vm.ctx).into()),
Ok(call_descr_get) => vm.invoke(&call_descr_get, (cls.clone(), cls)),
Ok(call_descr_get) => call_descr_get.call((cls.clone(), cls), vm),
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion vm/src/builtins/complex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ impl PyObjectRef {
return Ok(Some((complex.value, true)));
}
if let Some(method) = vm.get_method(self.clone(), identifier!(vm, __complex__)) {
let result = vm.invoke(&method?, ())?;
let result = method?.call((), vm)?;
// TODO: returning strict subclasses of complex in __complex__ is deprecated
return match result.payload::<PyComplex>() {
Some(complex_obj) => Ok(Some((complex_obj.value, true))),
Expand Down
4 changes: 2 additions & 2 deletions vm/src/builtins/dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ impl PyDict {
};
let dict = &self.entries;
if let Some(keys) = vm.get_method(other.clone(), vm.ctx.intern_str("keys")) {
let keys = vm.invoke(&keys?, ())?.get_iter(vm)?;
let keys = keys?.call((), vm)?.get_iter(vm)?;
while let PyIterReturn::Return(key) = keys.next(vm)? {
let val = other.get_item(&*key, vm)?;
dict.insert(vm, &*key, val)?;
Expand Down Expand Up @@ -529,7 +529,7 @@ impl Py<PyDict> {
vm: &VirtualMachine,
) -> PyResult<Option<PyObjectRef>> {
vm.get_method(self.to_owned().into(), identifier!(vm, __missing__))
.map(|methods| vm.invoke(&methods?, (key.to_pyobject(vm),)))
.map(|methods| methods?.call((key.to_pyobject(vm),), vm))
.transpose()
}

Expand Down
2 changes: 1 addition & 1 deletion vm/src/builtins/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ impl IterNext for PyFilter {
} else {
// the predicate itself can raise StopIteration which does stop the filter
// iteration
match PyIterReturn::from_pyresult(vm.invoke(predicate, (next_obj.clone(),)), vm)? {
match PyIterReturn::from_pyresult(predicate.call((next_obj.clone(),), vm), vm)? {
PyIterReturn::Return(obj) => obj,
PyIterReturn::StopIteration(v) => return Ok(PyIterReturn::StopIteration(v)),
}
Expand Down
2 changes: 1 addition & 1 deletion vm/src/builtins/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ impl Callable for PyBoundMethod {
#[inline]
fn call(zelf: &crate::Py<Self>, mut args: FuncArgs, vm: &VirtualMachine) -> PyResult {
args.prepend_arg(zelf.object.clone());
vm.invoke(&zelf.function, args)
zelf.function.call(args, vm)
}
}

Expand Down
Loading

0 comments on commit 5a74f08

Please sign in to comment.