Skip to content

Commit

Permalink
fixed docs
Browse files Browse the repository at this point in the history
  • Loading branch information
HMUNACHI committed Feb 12, 2024
1 parent 97252ba commit b97630f
Show file tree
Hide file tree
Showing 9 changed files with 2 additions and 2 deletions.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
4 changes: 2 additions & 2 deletions tests/test_sklearn_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def test_linear_regression(self):
x_data = jax.random.normal(jax.random.PRNGKey(0), (num_samples, input_dim))
y_data = jnp.dot(x_data, jnp.array([[2.0]])) - jnp.array([[-1.0]])
lr_model = LinearRegression(input_dim, output_dim)
lr_model.train(x_data, y_data)
lr_model.fit(x_data, y_data)
learned_weights, learned_bias = lr_model.get_params()
self.assertTrue(jnp.allclose(learned_weights, jnp.array([[2.0]]), atol=1e-1))
self.assertTrue(jnp.allclose(learned_bias, jnp.array([[1.0]]), atol=1e-1))
Expand All @@ -88,7 +88,7 @@ def test_logistic_regression(self):
logits = jnp.dot(x_data, jnp.array([0.5, -0.5])) - 0.1
y_data = (logits > 0).astype(jnp.float32)
lr_model = LogisticRegression(input_dim)
lr_model.train(x_data, y_data)
lr_model.fit(x_data, y_data)
test_data = jax.random.normal(jax.random.PRNGKey(0), (num_samples, input_dim))
predictions = lr_model.predict(test_data)
self.assertTrue(jnp.all(predictions >= 0) and jnp.all(predictions <= 1))
Expand Down

0 comments on commit b97630f

Please sign in to comment.