Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge branch 'master' into predict_cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
piiswrong committed Jan 6, 2017
2 parents 9785294 + 1ae2905 commit d6005b1
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 15 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ R-package/inst/*
*.tar.gz
*.tgz
R-package/man/*.Rd
R-package/R/mxnet_generated.R

# data
*.rec
Expand Down
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,10 @@ rpkg:
echo "import(methods)" >> R-package/NAMESPACE
R CMD INSTALL R-package
Rscript -e "require(mxnet); mxnet:::mxnet.export(\"R-package\")"
rm -rf R-package/NAMESPACE
Rscript -e "require(roxygen2); roxygen2::roxygenise(\"R-package\")"
R CMD build --no-build-vignettes R-package
rm -rf mxnet_current_r.tar.gz
mv mxnet_*.tar.gz mxnet_current_r.tar.gz

scalapkg:
Expand Down
2 changes: 1 addition & 1 deletion R-package/DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ BugReports: https://github.com/dmlc/mxnet/issues
Imports:
methods,
Rcpp (>= 0.12.1),
DiagrammeR (>= 0.8.1),
DiagrammeR (>= 0.9.0),
data.table,
jsonlite,
magrittr,
Expand Down
26 changes: 14 additions & 12 deletions R-package/R/viz.graph.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
#' @importFrom data.table :=
#' @importFrom data.table setkey
#' @importFrom jsonlite fromJSON
#' @importFrom DiagrammeR create_nodes
#' @importFrom DiagrammeR create_node_df
#' @importFrom DiagrammeR create_graph
#' @importFrom DiagrammeR create_edges
#' @importFrom DiagrammeR combine_edges
#' @importFrom DiagrammeR create_edge_df
#' @importFrom DiagrammeR combine_edfs
#' @importFrom DiagrammeR render_graph
#'
#' @param model a \code{string} representing the path to a file containing the \code{JSon} of a model dump or the actual model dump.
Expand Down Expand Up @@ -106,8 +106,8 @@ graph.viz <- function(model, graph.title = "Computation graph", graph.title.font
mx.model.nodes[,id] %>% unique %>% setdiff(nodes.to.keep) %>% sort

nodes <-
create_nodes(
nodes = mx.model.nodes[id %in% nodes.to.keep, id],
create_node_df(
n = length(mx.model.nodes[id %in% nodes.to.keep, id]),
label = mx.model.nodes[id %in% nodes.to.keep, label],
type = "lower",
style = "filled",
Expand All @@ -118,6 +118,8 @@ graph.viz <- function(model, graph.title = "Computation graph", graph.title.font
width = "1.3",
height = "0.8034"
)

nodes$id <- mx.model.nodes[id %in% nodes.to.keep, id]

mx.model.nodes[,has.connection:= sapply(inputs, function(x)
length(x) > 0)]
Expand All @@ -132,24 +134,24 @@ graph.viz <- function(model, graph.title = "Computation graph", graph.title.font
origin <-
nodes.to.insert[i, inputs][[1]][,1] %>% setdiff(nodes.to.remove) %>% unique
destination <- rep(current.id, length(origin))
edges.temp <- create_edges(from = origin,
to = destination,
edges.temp <- create_edge_df(from = as.character(origin),
to = as.character(destination),
relationship = "leading_to")
if (is.null(edges))
edges <- edges.temp
else
edges <- combine_edges(edges.temp, edges)
edges <- combine_edfs(edges.temp, edges)
}

graph <-
create_graph(
nodes_df = nodes,
edges_df = edges,
directed = TRUE,
edges_df = edges#,
# directed = TRUE#,
# node_attrs = c("fontname = Helvetica"),
graph_attrs = paste0("label = \"", graph.title, "\"") %>% c(paste0("fontname = ", graph.title.font.name)) %>% c(paste0("fontsize = ", graph.title.font.size)) %>% c("labelloc = t"),
# graph_attrs = paste0("label = \"", graph.title, "\"") %>% c(paste0("fontname = ", graph.title.font.name)) %>% c(paste0("fontsize = ", graph.title.font.size)) %>% c("labelloc = t"),
# node_attrs = "fontname = Helvetica",
edge_attrs = c("color = gray20", "arrowsize = 0.8", "arrowhead = vee")
# edge_attrs = c("color = gray20", "arrowsize = 0.8", "arrowhead = vee")
)

return(render_graph(graph, width = graph.width.px, height = graph.height.px))
Expand Down
8 changes: 6 additions & 2 deletions python/mxnet/module/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,13 @@ def save_checkpoint(self, prefix, epoch, save_optimizer_states=False):
Whether to save optimizer states for continue training
"""
self._symbol.save('%s-symbol.json'%prefix)
self.save_params('%s-%04d.params'%(prefix, epoch))
param_name = '%s-%04d.params' % (prefix, epoch)
self.save_params(param_name)
logging.info('Saved checkpoint to \"%s\"', param_name)
if save_optimizer_states:
self.save_optimizer_states('%s-%04d.states'%(prefix, epoch))
state_name = '%s-%04d.states' % (prefix, epoch)
self.save_optimizer_states(state_name)
logging.info('Saved optimizer state to \"%s\"', state_name)

def _reset_bind(self):
"""Internal function to reset binded state."""
Expand Down

0 comments on commit d6005b1

Please sign in to comment.