Skip to content
Merged

Gc #22

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,9 @@ func ToValue(from Object, to reflect.Value) bool {
t := to.Type()
to.Set(reflect.MakeMap(t))
dict := cast[Dict](from)
for key, value := range dict.Items() {
iter := dict.Iter()
for iter.HasNext() {
key, value := iter.Next()
vk := reflect.New(t.Key()).Elem()
vv := reflect.New(t.Elem()).Elem()
if !ToValue(key, vk) || !ToValue(value, vv) {
Expand Down
42 changes: 21 additions & 21 deletions dict.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,26 +75,26 @@ func (d Dict) Del(key Objecter) {
C.PyDict_DelItem(d.obj, key.Obj())
}

func (d Dict) Items() func(fn func(key, value Object) bool) {
return func(fn func(key, value Object) bool) {
items := C.PyDict_Items(d.obj)
check(items != nil, "failed to get items of dict")
defer C.Py_DecRef(items)
iter := C.PyObject_GetIter(items)
for {
item := C.PyIter_Next(iter)
if item == nil {
break
}
C.Py_IncRef(item)
key := C.PyTuple_GetItem(item, 0)
value := C.PyTuple_GetItem(item, 1)
C.Py_IncRef(key)
C.Py_IncRef(value)
C.Py_DecRef(item)
if !fn(newObject(key), newObject(value)) {
break
}
}
func (d Dict) Iter() *DictIter {
return &DictIter{dict: d, pos: 0}
}

type DictIter struct {
dict Dict
pos C.long
}

func (d *DictIter) HasNext() bool {
pos := d.pos
return C.PyDict_Next(d.dict.obj, &pos, nil, nil) != 0
}

func (d *DictIter) Next() (Object, Object) {
var key, value *C.PyObject
if C.PyDict_Next(d.dict.obj, &d.pos, &key, &value) == 0 {
return Nil(), Nil()
}
C.Py_IncRef(key)
C.Py_IncRef(value)
return newObject(key), newObject(value)
}
4 changes: 3 additions & 1 deletion dict_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,9 @@ func TestDictForEach(t *testing.T) {
"key3": "value3",
}

for key, value := range dict.Items() {
iter := dict.Iter()
for iter.HasNext() {
key, value := iter.Next()
count++
k := key.String()
v := value.String()
Expand Down
104 changes: 33 additions & 71 deletions extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ type slotMeta struct {
hasRecv bool // whether it has a receiver
index int // used for member type
typ reflect.Type // member/method type
def *C.PyMethodDef
}

type typeMeta struct {
Expand All @@ -60,9 +61,7 @@ type typeMeta struct {

func allocWrapper(typ *C.PyTypeObject, obj any) *wrapperType {
self := C.PyType_GenericAlloc(typ, 0)
if self == nil {
return nil
}
check(self != nil, "failed to allocate wrapper")
wrapper := (*wrapperType)(unsafe.Pointer(self))
holder := new(objectHolder)
holder.obj = obj
Expand All @@ -83,9 +82,7 @@ func wrapperAlloc(typ *C.PyTypeObject, size C.Py_ssize_t) *C.PyObject {
maps := getGlobalData()
meta := maps.typeMetas[(*C.PyObject)(unsafe.Pointer(typ))]
wrapper := allocWrapper(typ, reflect.New(meta.typ).Interface())
if wrapper == nil {
return nil
}
check(wrapper != nil, "failed to allocate wrapper")
return (*C.PyObject)(unsafe.Pointer(wrapper))
}

Expand All @@ -101,9 +98,8 @@ func wrapperInit(self, args *C.PyObject) C.int {
typ := (*C.PyObject)(self).ob_type
maps := getGlobalData()
typeMeta := maps.typeMetas[(*C.PyObject)(unsafe.Pointer(typ))]
if typeMeta.init == nil {
return 0
}
check(typeMeta != nil, "type not registered")
check(typeMeta.init != nil, "init method not found")
if wrapperMethod_(typeMeta, typeMeta.init, self, args, 0) == nil {
return -1
}
Expand All @@ -114,15 +110,9 @@ func wrapperInit(self, args *C.PyObject) C.int {
func getterMethod(self *C.PyObject, _closure unsafe.Pointer, methodId C.int) *C.PyObject {
maps := getGlobalData()
typeMeta := maps.typeMetas[(*C.PyObject)(unsafe.Pointer(self.ob_type))]
if typeMeta == nil {
SetError(fmt.Errorf("type %v not registered", FromPy(self)))
return nil
}
check(typeMeta != nil, fmt.Sprintf("type %v not registered", FromPy(self)))
methodMeta := typeMeta.methods[uint(methodId)]
if methodMeta == nil {
SetError(fmt.Errorf("getter method %d not found", methodId))
return nil
}
check(methodMeta != nil, fmt.Sprintf("getter method %d not found", methodId))

wrapper := (*wrapperType)(unsafe.Pointer(self))
goPtr := reflect.ValueOf(wrapper.goObj)
Expand All @@ -136,10 +126,7 @@ func getterMethod(self *C.PyObject, _closure unsafe.Pointer, methodId C.int) *C.
}
if pyType, ok := maps.pyTypes[fieldType.Elem()]; ok {
newWrapper := allocWrapper((*C.PyTypeObject)(unsafe.Pointer(pyType)), field.Interface())
if newWrapper == nil {
SetError(fmt.Errorf("failed to allocate wrapper for nested struct pointer"))
return nil
}
check(newWrapper != nil, "failed to allocate wrapper for nested struct pointer")
return (*C.PyObject)(unsafe.Pointer(newWrapper))
}
} else if field.Kind() == reflect.Struct {
Expand All @@ -148,10 +135,7 @@ func getterMethod(self *C.PyObject, _closure unsafe.Pointer, methodId C.int) *C.
fieldAddr := unsafe.Add(baseAddr, typeMeta.typ.Field(methodMeta.index).Offset)
fieldPtr := reflect.NewAt(fieldType, fieldAddr).Interface()
newWrapper := allocWrapper((*C.PyTypeObject)(unsafe.Pointer(pyType)), fieldPtr)
if newWrapper == nil {
SetError(fmt.Errorf("failed to allocate wrapper for nested struct"))
return nil
}
check(newWrapper != nil, "failed to allocate wrapper for nested struct")
return (*C.PyObject)(unsafe.Pointer(newWrapper))
}
}
Expand All @@ -162,29 +146,23 @@ func getterMethod(self *C.PyObject, _closure unsafe.Pointer, methodId C.int) *C.
func setterMethod(self, value *C.PyObject, _closure unsafe.Pointer, methodId C.int) C.int {
maps := getGlobalData()
typeMeta := maps.typeMetas[(*C.PyObject)(unsafe.Pointer(self.ob_type))]
if typeMeta == nil {
SetError(fmt.Errorf("type %v not registered", FromPy(self)))
return -1
}
check(typeMeta != nil, fmt.Sprintf("type %v not registered", FromPy(self)))
methodMeta := typeMeta.methods[uint(methodId)]
if methodMeta == nil {
SetError(fmt.Errorf("setter method %d not found", methodId))
return -1
}
check(methodMeta != nil, fmt.Sprintf("setter method %d not found", methodId))

wrapper := (*wrapperType)(unsafe.Pointer(self))
goPtr := reflect.ValueOf(wrapper.goObj)
goValue := goPtr.Elem()

structValue := goValue
if !structValue.CanSet() {
SetError(fmt.Errorf("struct value cannot be set"))
SetTypeError(fmt.Errorf("struct value cannot be set"))
return -1
}

field := structValue.Field(methodMeta.index)
if !field.CanSet() {
SetError(fmt.Errorf("field %s cannot be set", methodMeta.name))
SetTypeError(fmt.Errorf("field %s cannot be set", methodMeta.name))
return -1
}

Expand All @@ -210,7 +188,7 @@ func setterMethod(self, value *C.PyObject, _closure unsafe.Pointer, methodId C.i
}
valueWrapper := (*wrapperType)(unsafe.Pointer(value))
if valueWrapper == nil {
SetError(fmt.Errorf("invalid value for struct pointer field"))
SetTypeError(fmt.Errorf("invalid value for struct pointer field"))
return -1
}
field.Set(reflect.ValueOf(valueWrapper.goObj))
Expand All @@ -229,10 +207,6 @@ func setterMethod(self, value *C.PyObject, _closure unsafe.Pointer, methodId C.i
return -1
}
valueWrapper := (*wrapperType)(unsafe.Pointer(value))
if valueWrapper == nil {
SetError(fmt.Errorf("invalid value for struct field"))
return -1
}
baseAddr := goPtr.UnsafePointer()
fieldAddr := unsafe.Add(baseAddr, typeMeta.typ.Field(methodMeta.index).Offset)
fieldPtr := reflect.NewAt(fieldType, fieldAddr)
Expand All @@ -257,21 +231,13 @@ func wrapperMethod(self, args *C.PyObject, methodId C.int) *C.PyObject {

maps := getGlobalData()
typeMeta, ok := maps.typeMetas[key]
if !ok {
SetError(fmt.Errorf("type %v not registered", FromPy(key)))
return nil
}
check(ok, fmt.Sprintf("type %v not registered", FromPy(key)))

methodMeta := typeMeta.methods[uint(methodId)]
return wrapperMethod_(typeMeta, methodMeta, self, args, methodId)
}

func wrapperMethod_(typeMeta *typeMeta, methodMeta *slotMeta, self, args *C.PyObject, methodId C.int) *C.PyObject {
if methodMeta == nil {
SetError(fmt.Errorf("method %d not found", methodId))
return nil
}

methodType := methodMeta.typ
argc := C.PyTuple_Size(args)
expectedArgs := methodType.NumIn()
Expand Down Expand Up @@ -550,9 +516,11 @@ func (m Module) AddType(obj, init any, name, doc string) Object {
*currentSlot = slot
}

typeName := fmt.Sprintf("%s.%s", m.Name(), name)

totalSize := unsafe.Sizeof(wrapperType{})
spec := &C.PyType_Spec{
name: C.CString(name),
name: C.CString(typeName),
basicsize: C.int(totalSize),
flags: C.Py_TPFLAGS_DEFAULT,
slots: slotsPtr,
Expand Down Expand Up @@ -627,30 +595,30 @@ func (m Module) AddMethod(name string, fn any, doc string) Func {
}

methodId := uint(len(meta.methods))
meta.methods[methodId] = &slotMeta{

methodPtr := C.wrapperMethods[methodId]
cName := C.CString(name)
cDoc := C.CString(doc)

def := (*C.PyMethodDef)(C.malloc(C.size_t(unsafe.Sizeof(C.PyMethodDef{}))))
def.ml_name = cName
def.ml_meth = C.PyCFunction(methodPtr)
def.ml_flags = C.METH_VARARGS
def.ml_doc = cDoc

methodMeta := &slotMeta{
name: name,
methodName: name,
fn: fn,
typ: t,
doc: doc,
hasRecv: false,
def: def,
}

methodPtr := C.wrapperMethods[methodId]
cName := C.CString(name)
cDoc := C.CString(doc)

def := &C.PyMethodDef{
ml_name: cName,
ml_meth: C.PyCFunction(methodPtr),
ml_flags: C.METH_VARARGS,
ml_doc: cDoc,
}
meta.methods[methodId] = methodMeta

pyFunc := C.PyCFunction_NewEx(def, m.obj, m.obj)
if pyFunc == nil {
panic(fmt.Sprintf("Failed to create function %s", name))
}
check(pyFunc != nil, fmt.Sprintf("Failed to create function %s", name))

if C.PyModule_AddObjectRef(m.obj, cName, pyFunc) < 0 {
C.Py_DecRef(pyFunc)
Expand All @@ -660,12 +628,6 @@ func (m Module) AddMethod(name string, fn any, doc string) Func {
return newFunc(pyFunc)
}

func SetError(err error) {
errStr := C.CString(err.Error())
C.PyErr_SetString(C.PyExc_RuntimeError, errStr)
C.free(unsafe.Pointer(errStr))
}

func SetTypeError(err error) {
errStr := C.CString(err.Error())
C.PyErr_SetString(C.PyExc_TypeError, errStr)
Expand Down
40 changes: 27 additions & 13 deletions global_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"reflect"
"sync"
"sync/atomic"
"unsafe"
)

// ----------------------------------------------------------------------------
Expand Down Expand Up @@ -51,11 +52,15 @@ func (l *decRefList) add(obj *C.PyObject) {
l.mu.Unlock()
}

func (l *decRefList) decRefAll() {
var list []*C.PyObject
func (l *decRefList) len() int {
l.mu.Lock()
defer l.mu.Unlock()
return len(l.objects)
}

func (l *decRefList) decRefAll() {
l.mu.Lock()
list = l.objects
list := l.objects
l.objects = make([]*C.PyObject, 0, maxPyObjects*2)
l.mu.Unlock()

Expand All @@ -67,12 +72,12 @@ func (l *decRefList) decRefAll() {
// ----------------------------------------------------------------------------

type globalData struct {
typeMetas map[*C.PyObject]*typeMeta
pyTypes map[reflect.Type]*C.PyObject
holders holderList
decRefList decRefList
disableDecRef bool
finished int32
typeMetas map[*C.PyObject]*typeMeta
pyTypes map[reflect.Type]*C.PyObject
holders holderList
decRefList decRefList
finished int32
alwaysDecRef bool
}

var (
Expand All @@ -84,17 +89,16 @@ func getGlobalData() *globalData {
}

func (gd *globalData) addDecRef(obj *C.PyObject) {
if gd.disableDecRef {
return
}
if atomic.LoadInt32(&gd.finished) != 0 {
return
}
gd.decRefList.add(obj)
}

func (gd *globalData) decRefObjectsIfNeeded() {
gd.decRefList.decRefAll()
if gd.alwaysDecRef || gd.decRefList.len() >= maxPyObjects {
gd.decRefList.decRefAll()
}
}

// ----------------------------------------------------------------------------
Expand All @@ -111,5 +115,15 @@ func markFinished() {
}

func cleanupGlobal() {
for _, meta := range global.typeMetas {
for _, method := range meta.methods {
def := method.def
if def != nil {
C.free(unsafe.Pointer(def.ml_name))
C.free(unsafe.Pointer(def.ml_doc))
C.free(unsafe.Pointer(def))
}
}
}
global = nil
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
module github.com/cpunion/go-python

go 1.23
go 1.20
Loading