Thanks for showing interest in this. It made me fix some minor problems and add a few more forward calls, in case you have models with more than one input (language transformers usually take a mask as well, for example).
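Just to illustrate what "more than one input" means on the Python side, here is a rough sketch of a two-input model; the names (TinyTransformer, input_ids, mask) are made up for the example and are not from the project:
Code:
import torch

class TinyTransformer(torch.nn.Module):
    # hypothetical two-input model: token ids plus an attention mask
    def __init__(self, vocab_size=1000, dim=64):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, dim)
        self.out = torch.nn.Linear(dim, 10)

    def forward(self, input_ids, mask):
        x = self.embed(input_ids)                         # (batch, seq, dim)
        x = x * mask.unsqueeze(-1)                        # zero out padded positions
        x = x.sum(dim=1) / mask.sum(dim=1, keepdim=True)  # masked mean over the sequence
        return self.out(x)

model = TinyTransformer()
ids = torch.randint(0, 1000, (8, 16))   # batch of 8 sequences, length 16
mask = torch.ones(8, 16)                # all positions valid
logits = model(ids, mask)               # forward call with two inputs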
To get basic learning functionality (backpropagation using the gradient) you'd need more tensor functions, i.e. slicing/view/add/multiply/divide/..., and the loss functions would be great to have as well. However, almost everyone does this in Python, because the syntax allows straightforward tensor usage through operator overloading and so on.
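To show what that looks like in Python: tensor arithmetic, slicing and a loss are only a few operator calls away, and autograd tracks everything. This is a generic PyTorch snippet, not code from the project:
Code:
import torch

a = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
b = torch.tensor([[0.5, 0.5], [0.5, 0.5]])

c = a * b + 2.0          # multiply/add via operator overloading
d = c[:, 0]              # slicing gives a view of the first column
loss = torch.nn.functional.mse_loss(d, torch.tensor([1.0, 2.0]))
loss.backward()          # gradients for a are now in a.grad
print(a.grad)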
The whole machine learning thing is reducible to numerical optimization problems. Imagine you are standing on a landscape and you want to find its deepest point. What you would do is start from an arbitrary point and let something roll downhill until it stands still. You cannot be sure that there is no deeper point anywhere on the landscape, but you have reached a local minimum. The process of letting something roll down the landscape is called gradient descent. You only find local optima this way, not global optima, but the high number of dimensions/parameters of some problems often makes finding a global optimum infeasible anyway.
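As a toy example (nothing to do with the project, just to make the "rolling downhill" concrete): gradient descent on f(x) = (x - 3)^2 keeps stepping against the slope until it settles near the minimum at x = 3:
Code:
# plain gradient descent on f(x) = (x - 3)**2, derivative f'(x) = 2 * (x - 3)
x = 0.0            # arbitrary starting point on the "landscape"
lr = 0.1           # learning rate = step length
for step in range(50):
    grad = 2.0 * (x - 3.0)   # slope at the current position
    x -= lr * grad           # roll a bit downhill
print(x)           # close to 3.0, the (here global) minimum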
Frameworks like PyTorch and TensorFlow support you in finding the next local optimum. The basic elements you always work with are tensors. You add them, multiply them, slice them, ... and the frameworks keep track of the gradient for you (autograd). This means you convert the input image to a tensor object, push it through some neural network calculations (usually matrix multiplications of the input with weights etc.), calculate the loss (usually the difference to the expected output), push the gradient back along the calculations again, and the optimizer does a step with the learning rate as step length on a specific subset of weights/parameters inside the chain of calculations.
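In PyTorch that whole cycle boils down to a handful of calls. Here is a stripped-down single training step with made-up shapes and a stand-in model, just to show the mechanics:
Code:
import torch

model = torch.nn.Linear(784, 10)                   # stand-in for a real network
optim = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.MSELoss()

x = torch.rand(1, 784)                             # "input image" as a tensor
target = torch.zeros(1, 10); target[0, 3] = 1.0    # expected output (class 3)

result = model(x)              # forward pass through the calculations
loss = loss_fn(result, target) # difference to the expected output
optim.zero_grad()              # clear old gradients
loss.backward()                # autograd pushes the gradient back
optim.step()                   # one step of length lr on the weights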
You can see exactly this in train.py (I added some comments):
Code:
import torch


def train(model, dataset):
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=8,
                                              shuffle=True,
                                              num_workers=8)
    loss_fn = torch.nn.MSELoss()  # Mean Square Error as loss function
    optim = torch.optim.AdamW(model.parameters(), lr=0.0001)  # optimize only the model's weights with the given learning rate
    for epoch in range(3):
        loss_sum = 0.0
        count = 0
        model.train()
        for inputs, outputs in data_loader:
            batch_size = inputs.size(0)
            # turn the class labels into one-hot vectors so they match the model output
            outputs = torch.nn.functional.one_hot(outputs, 10).to(dtype=torch.float)
            result = model(inputs)
            loss = loss_fn(result, outputs)  # calculate the loss
            loss_sum += float(loss)
            count += batch_size
            optim.zero_grad()  # clear the gradients left over from the previous step
            loss.backward()  # push the gradient back along all calculations
            optim.step()  # do a step and update the weights
            if count >= 10000:
                break
        print(f"Epoch {epoch} finished with loss {loss_sum / count}")
        model.eval()
        torch.jit.save(torch.jit.script(model), f"epoch_{epoch}.pt")
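To call it you just need a model and a dataset that yields (input, label) pairs. For example with torchvision's MNIST (an assumption on my side, since the labels are one-hot encoded to 10 classes) and a tiny stand-in model:
Code:
import torch
import torchvision

dataset = torchvision.datasets.MNIST(
    "data", train=True, download=True,
    transform=torchvision.transforms.ToTensor())   # yields (image tensor, class label)

model = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(28 * 28, 10))                  # tiny stand-in model with 10 outputs

train(model, dataset)                              # runs the loop above, saves epoch_*.pt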
You can also see the whole thing as a complex construct of springs that are connected and under tension. As soon as you let go of the springs they will jump around, trying to relax and reduce the tension, until the whole system comes to rest.