-
Notifications
You must be signed in to change notification settings - Fork 48
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #8 from dvgodoy/weights
Weights
- Loading branch information
Showing
34 changed files
with
701 additions
and
75 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import numpy as np | ||
|
||
def load_data(n_dims=10, n_points=1000, classif_radius_fraction=0.5, only_sphere=False, shuffle=True, seed=13): | ||
""" | ||
Parameters | ||
---------- | ||
n_dims: int, optional | ||
Number of dimensions of the n-ball. Default is 10. | ||
n_points: int, optional | ||
Number of points in each parabola. Default is 1,000. | ||
classif_radius_fraction: float, optional | ||
Points farther away from the center than | ||
`classification_radius_fraction * ball radius` are | ||
considered to be positive cases. The remaining | ||
points are the negative cases. | ||
only_sphere: boolean | ||
If True, generates a n-sphere, that is, a hollow n-ball. | ||
Default is False. | ||
shuffle: boolean, optional | ||
If True, the points are shuffled. Default is True. | ||
seed: int, optional | ||
Random seed. Default is 13. | ||
Returns | ||
------- | ||
X, y: tuple of ndarray | ||
X is an array of shape (n_points, n_dims) containing the | ||
points in the n-ball. | ||
y is an array of shape (n_points, 1) containing the | ||
classes of the samples. | ||
""" | ||
radius = np.sqrt(n_dims) | ||
points = np.random.normal(size=(n_points, n_dims)) | ||
sphere = radius * points / np.linalg.norm(points, axis=1).reshape(-1, 1) | ||
if only_sphere: | ||
X = sphere | ||
else: | ||
X = sphere * np.random.uniform(size=(n_points, 1))**(1 / n_dims) | ||
|
||
adjustment = 1 / np.std(X) | ||
radius *= adjustment | ||
X *= adjustment | ||
|
||
y = (np.abs(np.sum(X, axis=1)) > (radius * classif_radius_fraction)).astype(np.int) | ||
|
||
# But we must not feed the network with neatly organized inputs... | ||
# so let's randomize them | ||
if shuffle: | ||
np.random.seed(seed) | ||
shuffled = np.random.permutation(range(X.shape[0])) | ||
X = X[shuffled] | ||
y = y[shuffled].reshape(-1, 1) | ||
|
||
return (X, y) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import itertools | ||
import numpy as np | ||
|
||
def load_data(n_dims=10, vertices=(-1., 1.), shuffle=True, seed=13): | ||
""" | ||
Parameters | ||
---------- | ||
n_dims: int, optional | ||
Number of dimensions of the hypercube. Default is 10. | ||
edge: tuple of floats, optional | ||
Two vertices of an edge. Default is (-1., 1.). | ||
shuffle: boolean, optional | ||
If True, the points are shuffled. Default is True. | ||
seed: int, optional | ||
Random seed. Default is 13. | ||
Returns | ||
------- | ||
X, y: tuple of ndarray | ||
X is an array of shape (2 ** n_dims, n_dims) containing the | ||
vertices coordinates of the hypercube. | ||
y is an array of shape (2 ** n_dims, 1) containing the | ||
classes of the samples. | ||
""" | ||
X = np.array(list(itertools.product(vertices, repeat=n_dims))) | ||
y = (np.sum(np.clip(X, a_min=0, a_max=1), axis=1) >= (n_dims / 2.0)).astype(np.int) | ||
|
||
# But we must not feed the network with neatly organized inputs... | ||
# so let's randomize them | ||
if shuffle: | ||
np.random.seed(seed) | ||
shuffled = np.random.permutation(range(X.shape[0])) | ||
X = X[shuffled] | ||
y = y[shuffled].reshape(-1, 1) | ||
|
||
return (X, y) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.