sm.go 1.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. package std
  2. import (
  3. "github.com/kpmy/ypk/assert"
  4. "gone/perc"
  5. "math/rand"
  6. )
  7. func New(nn, no int, w_fn func() int) (ret *model.Layer) {
  8. ret = &model.Layer{}
  9. for i := 0; i < nn; i++ {
  10. next := model.NewNode()
  11. next.Out = make([]interface{}, rand.Intn(no)+1)
  12. next.Weights = make([]int, len(next.Out))
  13. if w_fn != nil {
  14. for i := 0; i < len(next.Weights); i++ {
  15. next.Weights[i] = w_fn()
  16. }
  17. }
  18. ret.Nodes = append(ret.Nodes, next)
  19. }
  20. return
  21. }
  22. func Join(in *model.Layer, out *model.Layer) {
  23. j := 0
  24. type im map[int]interface{}
  25. cache := make([]im, len(out.Nodes))
  26. for k, n := range in.Nodes {
  27. for i := 0; i < len(n.Out); {
  28. assert.For(len(n.Out) <= len(out.Nodes), 20, len(n.Out), len(out.Nodes))
  29. l := model.Link{NodeId: k, LinkId: i}
  30. skip := false
  31. if cache[j] != nil {
  32. if _, ok := cache[j][k]; ok {
  33. skip = true
  34. j++
  35. if j == len(cache) {
  36. j = 0
  37. }
  38. }
  39. }
  40. if !skip {
  41. out.Nodes[j].In[l] = nil
  42. if cache[j] == nil {
  43. cache[j] = make(im)
  44. }
  45. cache[j][k] = l
  46. //log.Println(l, "to", j)
  47. i++
  48. }
  49. //log.Println(k, len(n.Out), i, j)
  50. }
  51. }
  52. in.Next = out
  53. }