This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Bug in Optimizer's serializeState and deserializeState methods (Scala) #14265

Closed
satyakrishnagorti opened this issue Feb 27, 2019 · 4 comments

Comments

@satyakrishnagorti
Contributor

Description

There is currently a bug in the way Optimizer serializes its state: deserialization fails for an Optimizer that has no state (for example, SGD without momentum).

Issue

Serialization is currently done as follows (pasted from Optimizer.serializeState()):

  override def serializeState(): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    try {
      val out = new ObjectOutputStream(bos)
      out.writeInt(states.size)
      states.foreach { case (k, v) =>
        if (v != null) {
          out.writeInt(k)
          val stateBytes = optimizer.serializeState(v)
          if (stateBytes == null) {
            out.writeInt(0)
          } else {
            out.writeInt(stateBytes.length)
            out.write(stateBytes)
          }
        }
      }
      out.flush()
      bos.toByteArray
    } finally {
      ...
    }
  }

When an Optimizer without state, such as SGD with momentum set to 0, is used, the states map (Map[Int, AnyRef]) contains a (key, value) pair of the form (some integer index, null).

Because of the if (v != null) check, the serialize method above writes neither the key k nor a length of 0 for such an entry, even though states.size, which has already been written, still counts it.

Deserialization is done as follows (pasted from Optimizer.deserializeState()):

  override def deserializeState(bytes: Array[Byte]): Unit = {
    val bis = new ByteArrayInputStream(bytes)
    var in: ObjectInputStream = null
    try {
      in = new ObjectInputStream(bis)
      val size = in.readInt()
      (0 until size).foreach(_ => {
        val key = in.readInt()
        val bytesLength = in.readInt()
        val value =
          if (bytesLength > 0) {
            val bytes = Array.fill[Byte](bytesLength)(0)
            in.readFully(bytes)
            optimizer.deserializeState(bytes)
          } else {
            null
          }
        states.update(key, value)
      })
    } finally {
      ...
    }
  }

In the foreach loop, a key is read for each of the size entries, but no key was written for the null entry, so the read runs past the end of the stream and throws a java.io.EOFException.
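The mismatch can be reproduced without MXNet. The following is a minimal sketch, assuming the states are held in a plain mutable.Map[Int, AnyRef]; the object name BuggyStateRepro and the serializeOne/deserializeOne parameters are hypothetical stand-ins for the updater and for optimizer.serializeState/deserializeState. It mirrors the two methods above and uses a map with a single null state, as with SGD with momentum 0.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, EOFException, ObjectInputStream, ObjectOutputStream}
import scala.collection.mutable

// Hypothetical, MXNet-free reproduction of the serialize/deserialize mismatch.
object BuggyStateRepro {
  // Mirrors the buggy writer: null entries are skipped, but states.size still counts them.
  def serialize(states: mutable.Map[Int, AnyRef],
                serializeOne: AnyRef => Array[Byte]): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bos)
    try {
      out.writeInt(states.size)
      states.foreach { case (k, v) =>
        if (v != null) {
          out.writeInt(k)
          val stateBytes = serializeOne(v)
          if (stateBytes == null) {
            out.writeInt(0)
          } else {
            out.writeInt(stateBytes.length)
            out.write(stateBytes)
          }
        }
      }
      out.flush()
      bos.toByteArray
    } finally {
      out.close()
    }
  }

  // Mirrors the reader: expects a key and a length for every counted entry.
  def deserialize(bytes: Array[Byte],
                  deserializeOne: Array[Byte] => AnyRef): mutable.Map[Int, AnyRef] = {
    val states = mutable.Map.empty[Int, AnyRef]
    val in = new ObjectInputStream(new ByteArrayInputStream(bytes))
    try {
      val size = in.readInt()
      (0 until size).foreach { _ =>
        val key = in.readInt() // fails here: no key was written for the null entry
        val bytesLength = in.readInt()
        val value =
          if (bytesLength > 0) {
            val buf = new Array[Byte](bytesLength)
            in.readFully(buf)
            deserializeOne(buf)
          } else null
        states.update(key, value)
      }
      states
    } finally {
      in.close()
    }
  }

  // Returns true if deserializing a map holding one null state throws EOFException.
  def reproducesEOF(): Boolean = {
    val states = mutable.Map[Int, AnyRef](0 -> null) // like SGD with momentum 0
    val bytes = serialize(states, v => v.asInstanceOf[Array[Byte]])
    try {
      deserialize(bytes, b => b)
      false
    } catch {
      case _: EOFException => true
    }
  }

  def main(args: Array[String]): Unit =
    println(s"EOFException reproduced: ${reproducesEOF()}")
}
```

The stream holds only the count (1), so the first in.readInt() inside the loop hits the end of the data.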

Solution

Remove the if (v != null) check and keep the rest of the method as is, so that every entry counted in states.size also has its key and length written.
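A sketch of the proposed fix, under the same assumptions as a standalone example (FixedStateSerde, serializeOne, deserializeOne and roundTrip are hypothetical names, not MXNet API). The only change to the writer is dropping the null guard. Because v is now passed through unconditionally, the per-state serializer must accept a null state; this sketch assumes it returns null, which the existing branch writes as length 0 and the reader turns back into null.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import scala.collection.mutable

// Hypothetical, MXNet-free sketch of the fixed state format.
object FixedStateSerde {
  // Same format as before, minus the `if (v != null)` guard.
  def serialize(states: mutable.Map[Int, AnyRef],
                serializeOne: AnyRef => Array[Byte]): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bos)
    try {
      out.writeInt(states.size)
      states.foreach { case (k, v) =>
        out.writeInt(k) // always written, so the reader's count matches
        val stateBytes = serializeOne(v) // assumed to return null for a null state
        if (stateBytes == null) {
          out.writeInt(0)
        } else {
          out.writeInt(stateBytes.length)
          out.write(stateBytes)
        }
      }
      out.flush()
      bos.toByteArray
    } finally {
      out.close()
    }
  }

  // Unchanged reader: a length of 0 is restored as a null state.
  def deserialize(bytes: Array[Byte],
                  deserializeOne: Array[Byte] => AnyRef): mutable.Map[Int, AnyRef] = {
    val states = mutable.Map.empty[Int, AnyRef]
    val in = new ObjectInputStream(new ByteArrayInputStream(bytes))
    try {
      val size = in.readInt()
      (0 until size).foreach { _ =>
        val key = in.readInt()
        val bytesLength = in.readInt()
        val value =
          if (bytesLength > 0) {
            val buf = new Array[Byte](bytesLength)
            in.readFully(buf)
            deserializeOne(buf)
          } else null
        states.update(key, value)
      }
      states
    } finally {
      in.close()
    }
  }

  // Round trip with stand-in per-state functions: states are raw byte arrays here.
  def roundTrip(states: mutable.Map[Int, AnyRef]): mutable.Map[Int, AnyRef] = {
    val bytes = serialize(states, v => if (v == null) null else v.asInstanceOf[Array[Byte]])
    deserialize(bytes, b => b)
  }

  def main(args: Array[String]): Unit = {
    val restored = roundTrip(mutable.Map[Int, AnyRef](0 -> null))
    println(s"restored ${restored.size} entry, state for key 0 is null: ${restored(0) == null}")
  }
}
```

With the guard removed, a map holding only a null state round-trips instead of throwing.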

@mxnet-label-bot
Contributor

Hey, this is the MXNet Label Bot.
Thank you for submitting the issue! I will try and suggest some labels so that the appropriate MXNet community members can help resolve it.
Here are my recommended labels: Scala, Bug

@leleamol
Contributor

@mxnet-label-bot add [Scala, Bug]

@zachgk
Contributor

zachgk commented Mar 1, 2019

Hey @satyakrishnagorti, it sounds like you have looked at this and know what to change. If you are interested, you are also more than welcome to open a pull request and contribute the fixes yourself. Let me know if you need any help.

@satyakrishnagorti
Contributor Author

Yeah, I will send a PR soon.
