-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDataLoader.py
33 lines (27 loc) · 956 Bytes
/
DataLoader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import numpy as np
import matplotlib.pyplot as plt
import torch
class SineWaveTask:
    """A randomly drawn sinusoid ``y = a * sin(x + b)`` used as a regression task.

    The amplitude ``a`` is drawn uniformly from [0.1, 5.0) and the phase ``b``
    uniformly from [0, 2*pi) when the task is created. Training inputs are
    sampled from U(-5, 5); test inputs are evenly spaced over [-5, 5].

    Args:
        rng: Optional source of randomness exposing ``uniform(low, high, size)``,
            e.g. ``np.random.default_rng(seed)``. Defaults to the global
            ``np.random`` state (the original behaviour). Pass a seeded
            generator to make a task and its training set reproducible.
    """

    def __init__(self, rng=None):
        # Fall back to the module-level NumPy RNG so existing callers draw
        # exactly the same numbers, in the same order, as before.
        self._rng = np.random if rng is None else rng
        self.a = self._rng.uniform(0.1, 5.0)
        self.b = self._rng.uniform(0, 2 * np.pi)
        # Cached training inputs; filled lazily by training_set().
        self.train_x = None

    def f(self, x):
        """Evaluate this task's sinusoid at ``x`` (scalar or NumPy array)."""
        return self.a * np.sin(x + self.b)

    def _to_tensors(self, x):
        """Pair inputs ``x`` with targets ``f(x)`` as float32 tensors."""
        y = self.f(x)
        return (torch.tensor(x, dtype=torch.float32),
                torch.tensor(y, dtype=torch.float32))

    def training_set(self, size=10, force_new=False):
        """Return ``(x, y)`` float32 tensors of training points.

        By default the inputs are sampled once from U(-5, 5) and cached on
        ``self.train_x``, so repeated calls return the same points. Once the
        cache is filled, ``size`` is ignored unless ``force_new`` is True.

        Args:
            size: Number of inputs to sample when a new draw is made.
            force_new: If True, draw a fresh sample of ``size`` inputs and
                return it without reading or updating the cached set.

        Returns:
            Tuple ``(x, y)`` of 1-D ``torch.float32`` tensors.
        """
        if force_new:
            x = self._rng.uniform(-5, 5, size)
        else:
            if self.train_x is None:
                self.train_x = self._rng.uniform(-5, 5, size)
            x = self.train_x
        return self._to_tensors(x)

    def test_set(self, size=50):
        """Return ``(x, y)`` float32 tensors on ``size`` evenly spaced points in [-5, 5]."""
        x = np.linspace(-5, 5, size)
        return self._to_tensors(x)

    def plot(self, *args, **kwargs):
        """Plot the task's curve on 100 test points; extra args go to ``plt.plot``.

        Returns whatever ``plt.plot`` returns.
        """
        x, y = self.test_set(size=100)
        return plt.plot(x.numpy(), y.numpy(), *args, **kwargs)