From f12999295d9001c9a42e861e880080597546bbdd Mon Sep 17 00:00:00 2001 From: io flament Date: Fri, 12 Jan 2018 19:24:38 -0300 Subject: [PATCH] adding missing modules to DEC (#9407) autoencoder, model, solver and data modules --- .../deep-embedded-clustering/autoencoder.py | 206 ++++++++++++++++++ example/deep-embedded-clustering/data.py | 40 ++++ example/deep-embedded-clustering/model.py | 78 +++++++ example/deep-embedded-clustering/solver.py | 151 +++++++++++++ 4 files changed, 475 insertions(+) create mode 100644 example/deep-embedded-clustering/autoencoder.py create mode 100644 example/deep-embedded-clustering/data.py create mode 100644 example/deep-embedded-clustering/model.py create mode 100644 example/deep-embedded-clustering/solver.py diff --git a/example/deep-embedded-clustering/autoencoder.py b/example/deep-embedded-clustering/autoencoder.py new file mode 100644 index 000000000000..096f04529c3b --- /dev/null +++ b/example/deep-embedded-clustering/autoencoder.py @@ -0,0 +1,206 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
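+
+# Stacked autoencoder used to initialise DEC (Xie et al. 2016, "Unsupervised
+# Deep Embedding for Clustering Analysis"): make_stack() builds one greedily
+# trainable encoder/decoder pair per layer, layerwise_pretrain() trains the
+# pairs one at a time, and finetune() then trains the whole encoder/decoder
+# stack end-to-end on an L2 reconstruction loss (the default).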
+ +# pylint: disable=missing-docstring, arguments-differ +from __future__ import print_function + +import logging + +import mxnet as mx +import numpy as np +import model +from solver import Solver, Monitor + + +class AutoEncoderModel(model.MXModel): + def setup(self, dims, sparseness_penalty=None, pt_dropout=None, + ft_dropout=None, input_act=None, internal_act='relu', output_act=None): + self.N = len(dims) - 1 + self.dims = dims + self.stacks = [] + self.pt_dropout = pt_dropout + self.ft_dropout = ft_dropout + self.input_act = input_act + self.internal_act = internal_act + self.output_act = output_act + + self.data = mx.symbol.Variable('data') + for i in range(self.N): + if i == 0: + decoder_act = input_act + idropout = None + else: + decoder_act = internal_act + idropout = pt_dropout + if i == self.N-1: + encoder_act = output_act + odropout = None + else: + encoder_act = internal_act + odropout = pt_dropout + istack, iargs, iargs_grad, iargs_mult, iauxs = self.make_stack( + i, self.data, dims[i], dims[i+1], sparseness_penalty, + idropout, odropout, encoder_act, decoder_act + ) + self.stacks.append(istack) + self.args.update(iargs) + self.args_grad.update(iargs_grad) + self.args_mult.update(iargs_mult) + self.auxs.update(iauxs) + self.encoder, self.internals = self.make_encoder( + self.data, dims, sparseness_penalty, ft_dropout, internal_act, output_act) + self.decoder = self.make_decoder( + self.encoder, dims, sparseness_penalty, ft_dropout, internal_act, input_act) + if input_act == 'softmax': + self.loss = self.decoder + else: + self.loss = mx.symbol.LinearRegressionOutput(data=self.decoder, label=self.data) + + def make_stack(self, istack, data, num_input, num_hidden, sparseness_penalty=None, + idropout=None, odropout=None, encoder_act='relu', decoder_act='relu'): + x = data + if idropout: + x = mx.symbol.Dropout(data=x, p=idropout) + x = mx.symbol.FullyConnected(name='encoder_%d'%istack, data=x, num_hidden=num_hidden) + if encoder_act: + x = mx.symbol.Activation(data=x, act_type=encoder_act) + if encoder_act == 'sigmoid' and sparseness_penalty: + x = mx.symbol.IdentityAttachKLSparseReg( + data=x, name='sparse_encoder_%d' % istack, penalty=sparseness_penalty) + if odropout: + x = mx.symbol.Dropout(data=x, p=odropout) + x = mx.symbol.FullyConnected(name='decoder_%d'%istack, data=x, num_hidden=num_input) + if decoder_act == 'softmax': + x = mx.symbol.Softmax(data=x, label=data, prob_label=True, act_type=decoder_act) + elif decoder_act: + x = mx.symbol.Activation(data=x, act_type=decoder_act) + if decoder_act == 'sigmoid' and sparseness_penalty: + x = mx.symbol.IdentityAttachKLSparseReg( + data=x, name='sparse_decoder_%d' % istack, penalty=sparseness_penalty) + x = mx.symbol.LinearRegressionOutput(data=x, label=data) + else: + x = mx.symbol.LinearRegressionOutput(data=x, label=data) + + args = {'encoder_%d_weight'%istack: mx.nd.empty((num_hidden, num_input), self.xpu), + 'encoder_%d_bias'%istack: mx.nd.empty((num_hidden,), self.xpu), + 'decoder_%d_weight'%istack: mx.nd.empty((num_input, num_hidden), self.xpu), + 'decoder_%d_bias'%istack: mx.nd.empty((num_input,), self.xpu),} + args_grad = {'encoder_%d_weight'%istack: mx.nd.empty((num_hidden, num_input), self.xpu), + 'encoder_%d_bias'%istack: mx.nd.empty((num_hidden,), self.xpu), + 'decoder_%d_weight'%istack: mx.nd.empty((num_input, num_hidden), self.xpu), + 'decoder_%d_bias'%istack: mx.nd.empty((num_input,), self.xpu),} + args_mult = {'encoder_%d_weight'%istack: 1.0, + 'encoder_%d_bias'%istack: 2.0, + 'decoder_%d_weight'%istack: 1.0, + 
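# note: biases get a 2x learning-rate multiplier, a common convention
+                 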
'decoder_%d_bias'%istack: 2.0,} + auxs = {} + if encoder_act == 'sigmoid' and sparseness_penalty: + auxs['sparse_encoder_%d_moving_avg' % istack] = mx.nd.ones(num_hidden, self.xpu) * 0.5 + if decoder_act == 'sigmoid' and sparseness_penalty: + auxs['sparse_decoder_%d_moving_avg' % istack] = mx.nd.ones(num_input, self.xpu) * 0.5 + init = mx.initializer.Uniform(0.07) + for k, v in args.items(): + init(mx.initializer.InitDesc(k), v) + + return x, args, args_grad, args_mult, auxs + + def make_encoder(self, data, dims, sparseness_penalty=None, dropout=None, internal_act='relu', + output_act=None): + x = data + internals = [] + N = len(dims) - 1 + for i in range(N): + x = mx.symbol.FullyConnected(name='encoder_%d'%i, data=x, num_hidden=dims[i+1]) + if internal_act and i < N-1: + x = mx.symbol.Activation(data=x, act_type=internal_act) + if internal_act == 'sigmoid' and sparseness_penalty: + x = mx.symbol.IdentityAttachKLSparseReg( + data=x, name='sparse_encoder_%d' % i, penalty=sparseness_penalty) + elif output_act and i == N-1: + x = mx.symbol.Activation(data=x, act_type=output_act) + if output_act == 'sigmoid' and sparseness_penalty: + x = mx.symbol.IdentityAttachKLSparseReg( + data=x, name='sparse_encoder_%d' % i, penalty=sparseness_penalty) + if dropout: + x = mx.symbol.Dropout(data=x, p=dropout) + internals.append(x) + return x, internals + + def make_decoder(self, feature, dims, sparseness_penalty=None, dropout=None, + internal_act='relu', input_act=None): + x = feature + N = len(dims) - 1 + for i in reversed(range(N)): + x = mx.symbol.FullyConnected(name='decoder_%d'%i, data=x, num_hidden=dims[i]) + if internal_act and i > 0: + x = mx.symbol.Activation(data=x, act_type=internal_act) + if internal_act == 'sigmoid' and sparseness_penalty: + x = mx.symbol.IdentityAttachKLSparseReg( + data=x, name='sparse_decoder_%d' % i, penalty=sparseness_penalty) + elif input_act and i == 0: + x = mx.symbol.Activation(data=x, act_type=input_act) + if input_act == 'sigmoid' and sparseness_penalty: + x = mx.symbol.IdentityAttachKLSparseReg( + data=x, name='sparse_decoder_%d' % i, penalty=sparseness_penalty) + if dropout and i > 0: + x = mx.symbol.Dropout(data=x, p=dropout) + return x + + def layerwise_pretrain(self, X, batch_size, n_iter, optimizer, l_rate, decay, + lr_scheduler=None, print_every=1000): + def l2_norm(label, pred): + return np.mean(np.square(label-pred))/2.0 + solver = Solver(optimizer, momentum=0.9, wd=decay, learning_rate=l_rate, + lr_scheduler=lr_scheduler) + solver.set_metric(mx.metric.CustomMetric(l2_norm)) + solver.set_monitor(Monitor(print_every)) + data_iter = mx.io.NDArrayIter({'data': X}, batch_size=batch_size, shuffle=True, + last_batch_handle='roll_over') + for i in range(self.N): + if i == 0: + data_iter_i = data_iter + else: + X_i = list(model.extract_feature( + self.internals[i-1], self.args, self.auxs, data_iter, X.shape[0], + self.xpu).values())[0] + data_iter_i = mx.io.NDArrayIter({'data': X_i}, batch_size=batch_size, + last_batch_handle='roll_over') + logging.info('Pre-training layer %d...', i) + solver.solve(self.xpu, self.stacks[i], self.args, self.args_grad, self.auxs, + data_iter_i, 0, n_iter, {}, False) + + def finetune(self, X, batch_size, n_iter, optimizer, l_rate, decay, lr_scheduler=None, + print_every=1000): + def l2_norm(label, pred): + return np.mean(np.square(label-pred))/2.0 + solver = Solver(optimizer, momentum=0.9, wd=decay, learning_rate=l_rate, + lr_scheduler=lr_scheduler) + solver.set_metric(mx.metric.CustomMetric(l2_norm)) + 
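# the Monitor hook logs weight/gradient statistics every `print_every` iterations
+        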
solver.set_monitor(Monitor(print_every)) + data_iter = mx.io.NDArrayIter({'data': X}, batch_size=batch_size, shuffle=True, + last_batch_handle='roll_over') + logging.info('Fine tuning...') + solver.solve(self.xpu, self.loss, self.args, self.args_grad, self.auxs, data_iter, + 0, n_iter, {}, False) + + def eval(self, X): + batch_size = 100 + data_iter = mx.io.NDArrayIter({'data': X}, batch_size=batch_size, shuffle=False, + last_batch_handle='pad') + Y = list(model.extract_feature( + self.loss, self.args, self.auxs, data_iter, X.shape[0], self.xpu).values())[0] + return np.mean(np.square(Y-X))/2.0 \ No newline at end of file diff --git a/example/deep-embedded-clustering/data.py b/example/deep-embedded-clustering/data.py new file mode 100644 index 000000000000..9fd472e6a8b1 --- /dev/null +++ b/example/deep-embedded-clustering/data.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=missing-docstring +from __future__ import print_function + +import os +import numpy as np +from sklearn.datasets import fetch_mldata + + +def get_mnist(): + """ Gets MNIST dataset """ + + np.random.seed(1234) # set seed for deterministic ordering + data_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) + data_path = os.path.join(data_path, '../../data') + mnist = fetch_mldata('MNIST original', data_home=data_path) + p = np.random.permutation(mnist.data.shape[0]) + X = mnist.data[p].astype(np.float32)*0.02 + Y = mnist.target[p] + return X, Y + + + + diff --git a/example/deep-embedded-clustering/model.py b/example/deep-embedded-clustering/model.py new file mode 100644 index 000000000000..777634e3cf88 --- /dev/null +++ b/example/deep-embedded-clustering/model.py @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
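+
+# Shared plumbing for the DEC example: extract_feature() runs a bound symbol
+# forward over a data iterator and collects its outputs as numpy arrays,
+# while MXModel is a small base class owning the parameter, gradient and
+# auxiliary-state dicts, with pickle-based save()/load(); subclasses
+# implement setup().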
+
+# pylint: disable=missing-docstring
+from __future__ import print_function
+
+import mxnet as mx
+import numpy as np
+try:
+    import cPickle as pickle
+except ImportError:  # cPickle is Python 2 only; ModuleNotFoundError would NameError on Python < 3.6
+    import pickle
+
+
+def extract_feature(sym, args, auxs, data_iter, N, xpu=mx.cpu()):
+    input_buffs = [mx.nd.empty(shape, ctx=xpu) for k, shape in data_iter.provide_data]
+    input_names = [k for k, shape in data_iter.provide_data]
+    args = dict(args, **dict(zip(input_names, input_buffs)))
+    exe = sym.bind(xpu, args=args, aux_states=auxs)
+    outputs = [[] for _ in exe.outputs]
+    output_buffs = None
+
+    data_iter.hard_reset()
+    for batch in data_iter:
+        for data, buff in zip(batch.data, input_buffs):
+            data.copyto(buff)
+        exe.forward(is_train=False)
+        if output_buffs is None:
+            output_buffs = [mx.nd.empty(i.shape, ctx=mx.cpu()) for i in exe.outputs]
+        else:
+            # output_buffs still hold the previous batch; flush them before reuse
+            for out, buff in zip(outputs, output_buffs):
+                out.append(buff.asnumpy())
+        for out, buff in zip(exe.outputs, output_buffs):
+            out.copyto(buff)
+    for out, buff in zip(outputs, output_buffs):
+        out.append(buff.asnumpy())
+    outputs = [np.concatenate(i, axis=0)[:N] for i in outputs]
+    return dict(zip(sym.list_outputs(), outputs))
+
+
+class MXModel(object):
+    def __init__(self, xpu=mx.cpu(), *args, **kwargs):
+        self.xpu = xpu
+        self.loss = None
+        self.args = {}
+        self.args_grad = {}
+        self.args_mult = {}
+        self.auxs = {}
+        self.setup(*args, **kwargs)
+
+    def save(self, fname):
+        args_save = {key: v.asnumpy() for key, v in self.args.items()}
+        with open(fname, 'wb') as fout:
+            pickle.dump(args_save, fout)
+
+    def load(self, fname):
+        with open(fname, 'rb') as fin:
+            args_save = pickle.load(fin)
+            for key, v in args_save.items():
+                if key in self.args:
+                    self.args[key][:] = v
+
+    def setup(self, *args, **kwargs):
+        raise NotImplementedError("must override this")
\ No newline at end of file
diff --git a/example/deep-embedded-clustering/solver.py b/example/deep-embedded-clustering/solver.py
new file mode 100644
index 000000000000..567c78eeb06c
--- /dev/null
+++ b/example/deep-embedded-clustering/solver.py
@@ -0,0 +1,151 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
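+
+# Minimal hand-rolled training loop for the DEC example: Monitor logs a
+# statistic (mean absolute value by default) of weights and gradients every
+# `interval` iterations, and Solver binds a symbol to an executor and runs
+# forward/backward with per-parameter SGD-style updates for a fixed number
+# of iterations, honouring per-argument learning-rate multipliers.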
+ +# pylint: disable=missing-docstring +from __future__ import print_function + +import logging + +import mxnet as mx +import numpy as np + + +class Monitor(object): + def __init__(self, interval, level=logging.DEBUG, stat=None): + self.interval = interval + self.level = level + if stat is None: + def mean_abs(x): + return np.fabs(x).mean() + self.stat = mean_abs + else: + self.stat = stat + + def forward_end(self, i, internals): + if i % self.interval == 0 and logging.getLogger().isEnabledFor(self.level): + for key in sorted(internals.keys()): + arr = internals[key] + logging.log(self.level, 'Iter:%d param:%s\t\tstat(%s):%s', + i, key, self.stat.__name__, str(self.stat(arr.asnumpy()))) + + def backward_end(self, i, weights, grads, metric=None): + if i % self.interval == 0 and logging.getLogger().isEnabledFor(self.level): + for key in sorted(grads.keys()): + arr = grads[key] + logging.log(self.level, 'Iter:%d param:%s\t\tstat(%s):%s\t\tgrad_stat:%s', + i, key, self.stat.__name__, + str(self.stat(weights[key].asnumpy())), str(self.stat(arr.asnumpy()))) + if i % self.interval == 0 and metric is not None: + logging.log(logging.INFO, 'Iter:%d metric:%f', i, metric.get()[1]) + metric.reset() + + +class Solver(object): + def __init__(self, optimizer, **kwargs): + if isinstance(optimizer, str): + self.optimizer = mx.optimizer.create(optimizer, **kwargs) + else: + self.optimizer = optimizer + self.updater = mx.optimizer.get_updater(self.optimizer) + self.monitor = None + self.metric = None + self.iter_end_callback = None + self.iter_start_callback = None + + def set_metric(self, metric): + self.metric = metric + + def set_monitor(self, monitor): + self.monitor = monitor + + def set_iter_end_callback(self, callback): + self.iter_end_callback = callback + + def set_iter_start_callback(self, callback): + self.iter_start_callback = callback + + def solve(self, xpu, sym, args, args_grad, auxs, + data_iter, begin_iter, end_iter, args_lrmult=None, debug=False): + if args_lrmult is None: + args_lrmult = dict() + input_desc = data_iter.provide_data + data_iter.provide_label + input_names = [k for k, shape in input_desc] + input_buffs = [mx.nd.empty(shape, ctx=xpu) for k, shape in input_desc] + args = dict(args, **dict(zip(input_names, input_buffs))) + + output_names = sym.list_outputs() + if debug: + sym_group = [] + for x in sym.get_internals(): + if x.name not in args: + if x.name not in output_names: + x = mx.symbol.BlockGrad(x, name=x.name) + sym_group.append(x) + sym = mx.symbol.Group(sym_group) + exe = sym.bind(xpu, args=args, args_grad=args_grad, aux_states=auxs) + + assert len(sym.list_arguments()) == len(exe.grad_arrays) + update_dict = { + name: nd for name, nd in zip(sym.list_arguments(), exe.grad_arrays) if nd is not None + } + batch_size = input_buffs[0].shape[0] + self.optimizer.rescale_grad = 1.0/batch_size + self.optimizer.set_lr_mult(args_lrmult) + + output_dict = {} + output_buff = {} + internal_dict = dict(zip(input_names, input_buffs)) + for key, arr in zip(sym.list_outputs(), exe.outputs): + if key in output_names: + output_dict[key] = arr + output_buff[key] = mx.nd.empty(arr.shape, ctx=mx.cpu()) + else: + internal_dict[key] = arr + + data_iter.reset() + for i in range(begin_iter, end_iter): + if self.iter_start_callback is not None: + if self.iter_start_callback(i): + return + try: + batch = data_iter.next() + except StopIteration: + data_iter.reset() + batch = data_iter.next() + for data, buff in zip(batch.data+batch.label, input_buffs): + data.copyto(buff) + 
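# forward pass, then backprop and apply the updater to each parameter with a gradient
+            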
exe.forward(is_train=True) + if self.monitor is not None: + self.monitor.forward_end(i, internal_dict) + for key in output_dict: + output_dict[key].copyto(output_buff[key]) + + exe.backward() + for key, arr in update_dict.items(): + self.updater(key, arr, args[key]) + + if self.metric is not None: + self.metric.update([input_buffs[-1]], + [output_buff[output_names[0]]]) + + if self.monitor is not None: + self.monitor.backward_end(i, args, update_dict, self.metric) + + if self.iter_end_callback is not None: + if self.iter_end_callback(i): + return + exe.outputs[0].wait_to_read() \ No newline at end of file
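
For review context, the four modules compose as follows: data.get_mnist() supplies
the scaled MNIST matrix, AutoEncoderModel (an MXModel subclass) owns the symbols
and parameters, and Solver drives both the layer-wise pre-training and the
fine-tuning. Below is a minimal driver sketch; the d-500-500-2000-10 layer sizes
follow the DEC paper, but the batch size, iteration counts, learning-rate schedule
and output file name are illustrative assumptions, not values taken from this patch.

    import logging
    import mxnet as mx
    from autoencoder import AutoEncoderModel
    import data

    logging.basicConfig(level=logging.INFO)
    X, _ = data.get_mnist()  # (70000, 784) float32, scaled by 0.02
    # encoder dims d-500-500-2000-10, as in the DEC paper (assumed here)
    ae = AutoEncoderModel(mx.cpu(), [X.shape[1], 500, 500, 2000, 10],
                          pt_dropout=0.2, internal_act='relu', output_act='relu')
    ae.layerwise_pretrain(X, batch_size=256, n_iter=50000, optimizer='sgd',
                          l_rate=0.1, decay=0.0,
                          lr_scheduler=mx.lr_scheduler.FactorScheduler(20000, 0.1))
    ae.finetune(X, batch_size=256, n_iter=100000, optimizer='sgd',
                l_rate=0.1, decay=0.0,
                lr_scheduler=mx.lr_scheduler.FactorScheduler(20000, 0.1))
    print('reconstruction error:', ae.eval(X))
    ae.save('mnist_autoencoder.arg')  # pickled parameter dict; file name is illustrative

The clustering step itself is not in these four files; presumably the example's
main DEC script imports them and adds the KL-divergence clustering objective on
top of the pre-trained encoder.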