All it should do for now is attempt to regurgitate the training input, which only serves as proof that it works.
Next steps: train it on a source file to see whether it can predict anything useful.
Also, change it from using single characters to words — you would use a trie to map characters to words, then feed those into the RNN.
The goal of training is to get the loss to converge as close to 0.0 as possible, though you will see it jump up sometimes.
Code: Select all
;RNN Character-Level Prediction
;Trains a minimal recurrent network to predict the next character of a text.
EnableExplicit
RandomSeed(3) ; fixed seed so the weight initialization (and results) are reproducible
;Holds the whole network: layer sizes, per-step activations, weights, biases,
;and the character<->index vocabulary tables.
;Allocate with RNN_New(), then size with RNN_Init() after building the vocabulary.
Structure RNN
NumInputs.i ; Vocabulary size (number of unique characters)
NumHidden.i ; Number of hidden (recurrent) units
NumOutputs.i ; Same as vocabulary size
LearningRate.d ; Step size for gradient updates (may be decayed during training)
;Data (activations for the current time step)
Array inputs.d(0) ; one-hot encoding of the current character
Array hidden.d(0) ; hidden state h_t
Array hiddenPrev.d(0) ; hidden state carried between steps
Array outputs.d(0) ; softmax probabilities over the next character
;Weights
Array Wih.d(0,0) ; input -> hidden
Array Whh.d(0,0) ; hidden -> hidden (recurrent)
Array Who.d(0,0) ; hidden -> output
;Biases
Array bh.d(0) ; hidden-layer bias
Array bo.d(0) ; output-layer bias
;Vocabulary mapping
Map charToIndex.i() ; Character -> Index
Array indexToChar.s(0) ; Index -> Character
vocabSize.i ; number of distinct characters in the training text
EndStructure
; Allocate a zeroed RNN structure on the heap and return its pointer.
; The caller must run RNN_Init() before using it.
Procedure RNN_New()
  ProcedureReturn AllocateStructure(RNN)
EndProcedure
; Size all arrays and randomize the weights of an RNN created by RNN_New().
; Weights use Xavier-style initialization: uniform in roughly ±0.5*Sqr(1/fanIn).
; (Random($ffffffff) - $7FFFFFFF) spans about ±2^31; * 2.328e-10 maps that to ±0.5.
Procedure RNN_Init(*rnn.RNN,vocabSize, NumHidden, LearningRate.d=0.01)
  Protected row, col, scale.d

  *rnn\NumInputs = vocabSize
  *rnn\NumHidden = NumHidden
  *rnn\NumOutputs = vocabSize ; one output logit per possible character
  *rnn\vocabSize = vocabSize
  *rnn\LearningRate = LearningRate

  ; Activations
  Dim *rnn\inputs.d(vocabSize)
  Dim *rnn\hidden.d(NumHidden)
  Dim *rnn\hiddenPrev.d(NumHidden)
  Dim *rnn\outputs.d(vocabSize)
  ; Weight matrices
  Dim *rnn\Wih.d(NumHidden, vocabSize) ; Input to Hidden
  Dim *rnn\Whh.d(NumHidden, NumHidden) ; Hidden to Hidden recurrent
  Dim *rnn\Who.d(vocabSize, NumHidden) ; Hidden to Output
  ; Biases
  Dim *rnn\bh.d(NumHidden)
  Dim *rnn\bo.d(vocabSize)

  ; Input->hidden weights: fan-in is the vocabulary size
  scale = Sqr(1.0 / vocabSize)
  For row = 0 To NumHidden - 1
    For col = 0 To vocabSize - 1
      *rnn\Wih(row,col) = (Random($ffffffff) - $7FFFFFFF) * 0.0000000002328 * scale
    Next
    *rnn\bh(row) = 0.0
  Next

  ; Recurrent and output weights: fan-in is the hidden size
  scale = Sqr(1.0 / NumHidden)
  For row = 0 To NumHidden - 1
    For col = 0 To NumHidden - 1
      *rnn\Whh(row,col) = (Random($ffffffff) - $7FFFFFFF) * 0.0000000002328 * scale
    Next
  Next
  For row = 0 To vocabSize - 1
    For col = 0 To NumHidden - 1
      *rnn\Who(row,col) = (Random($ffffffff) - $7FFFFFFF) * 0.0000000002328 * scale
    Next
    *rnn\bo(row) = 0.0
  Next

  ProcedureReturn *rnn
EndProcedure
; Scan the training text, register each distinct character, and build both
; lookup directions: charToIndex (map) and indexToChar (array).
; Returns the vocabulary size.
Procedure RNN_BuildVocabulary(*rnn.RNN, text.s)
  Protected pos, nextIndex
  Protected ch.s

  ; Pass 1: register every distinct character as a map key (index filled later)
  For pos = 1 To Len(text)
    ch = Mid(text, pos, 1)
    If Not FindMapElement(*rnn\charToIndex(), ch)
      *rnn\charToIndex(ch) = 0
    EndIf
  Next

  ; Pass 2: hand out consecutive indices and fill the reverse lookup table
  ReDim *rnn\indexToChar(MapSize(*rnn\charToIndex()))
  nextIndex = 0
  ForEach *rnn\charToIndex()
    *rnn\charToIndex() = nextIndex
    *rnn\indexToChar(nextIndex) = MapKey(*rnn\charToIndex())
    nextIndex + 1
  Next

  Debug "Vocabulary size: " + Str(nextIndex) + " unique characters"
  ProcedureReturn nextIndex
EndProcedure
; Encode a single character as a one-hot vector in *rnn\inputs.
; BUGFIX: the hot element must be 1.0 — it was previously set to the
; character's INDEX, so the character at index 0 produced an all-zero input
; and every other input was scaled by its vocabulary position.
Procedure RNN_CharToOneHot(*rnn.RNN, char.s)
  Protected i, idx
  ; Clear input vector
  For i = 0 To *rnn\vocabSize - 1
    *rnn\inputs(i) = 0.0
  Next
  ; Set the element for this character to 1.0
  If FindMapElement(*rnn\charToIndex(), char)
    idx = *rnn\charToIndex()
    *rnn\inputs(idx) = 1.0
  Else
    Debug "RNN_CharToOneHot: character not in vocabulary: '" + char + "'"
  EndIf
EndProcedure
;Activation functions
; Hyperbolic tangent computed via Exp.
; NOTE(review): appears unused — RNN_Forward calls the built-in TanH(); kept for reference.
; FIX: for large x, Exp(2*x) overflows to +Infinity and (Inf-1)/(Inf+1) is NaN;
; clamp to the asymptotes first (tanh(20) already rounds to 1.0 in double).
Procedure.d _TanH(x.d)
  Protected ex.d
  If x > 20.0
    ProcedureReturn 1.0
  ElseIf x < -20.0
    ProcedureReturn -1.0
  EndIf
  ex = Exp(2.0 * x)
  ProcedureReturn (ex - 1.0) / (ex + 1.0)
EndProcedure
; Derivative of tanh expressed in terms of the tanh OUTPUT value:
; d/dx tanh(x) = 1 - tanh(x)^2, so callers pass the activation, not the input.
Procedure.d TanhDerivative(x.d)
  Protected squared.d = x * x
  ProcedureReturn 1.0 - squared
EndProcedure
; In-place softmax over *rnn\outputs: converts raw logits into a probability
; distribution. Subtracting the max logit first keeps Exp() from overflowing.
Procedure Softmax(*rnn.RNN)
  Protected k
  Protected peak.d = -1000000000.0
  Protected total.d = 0.0

  ; Locate the largest logit (numerical stability shift)
  For k = 0 To *rnn\vocabSize - 1
    If *rnn\outputs(k) > peak
      peak = *rnn\outputs(k)
    EndIf
  Next

  ; Exponentiate shifted logits and accumulate the partition sum
  For k = 0 To *rnn\vocabSize - 1
    *rnn\outputs(k) = Exp(*rnn\outputs(k) - peak)
    total + *rnn\outputs(k)
  Next

  ; Normalize so the outputs sum to 1
  For k = 0 To *rnn\vocabSize - 1
    *rnn\outputs(k) / total
  Next
EndProcedure
;Forward pass: h = tanh(Wih*x + Whh*h_prev + bh), y = softmax(Who*h + bo).
;Pass resetState <> 0 to clear the recurrent state before this step.
Procedure RNN_Forward(*rnn.RNN, resetState = 0)
  Protected i, j
  Protected sum.d ; FIX: was .f — accumulate in double precision, matching the .d weight arrays
  Protected NumHidden = *rnn\NumHidden
  Protected NumInputs = *rnn\NumInputs
  Protected NumOutputs = *rnn\NumOutputs
  ;Optionally reset hidden state (start of a new, unrelated sequence)
  If resetState
    For i = 0 To NumHidden - 1
      *rnn\hiddenPrev(i) = 0.0
    Next
  EndIf
  ;Hidden layer: bias + input contribution + recurrent contribution
  For i = 0 To NumHidden - 1
    sum = *rnn\bh(i)
    ;Input contribution
    For j = 0 To NumInputs - 1
      sum + *rnn\Wih(i,j) * *rnn\inputs(j)
    Next
    ;Recurrent contribution
    For j = 0 To NumHidden - 1
      sum + *rnn\Whh(i,j) * *rnn\hiddenPrev(j)
    Next
    *rnn\hidden(i) = TanH(sum)
  Next
  ;Output layer logits: y = Who * h + bo
  For i = 0 To NumOutputs - 1
    sum = *rnn\bo(i)
    For j = 0 To NumHidden - 1
      sum + *rnn\Who(i,j) * *rnn\hidden(j)
    Next
    *rnn\outputs(i) = sum
  Next
  ;Logits -> probabilities
  Softmax(*rnn)
  ;Carry the hidden state forward to the next time step
  CopyArray(*rnn\hidden(),*rnn\hiddenPrev())
EndProcedure
; Single training step with softmax + cross-entropy loss; the output gradient
; is simply (one-hot target - predicted).
; NOTE(review): this is truncated backprop of depth 1 — gradients do not flow
; back through earlier time steps. Also, RNN_Forward copies hidden into
; hiddenPrev before this runs, so hiddenPrev here equals the CURRENT hidden
; state, not h_(t-1) — confirm whether that is intended for the Whh update.
Procedure RNN_Train(*rnn.RNN, targetCharIndex.i)
  Protected i, j
  Protected NumOutputs = *rnn\NumOutputs
  Protected NumHidden = *rnn\NumHidden
  Protected NumInputs = *rnn\NumInputs
  Protected Dim outputGrad.d(NumOutputs)  ; FIX: was .f — keep gradients in double like the weights
  Protected Dim hiddenError.d(NumHidden)  ; FIX: was .f
  ; Output gradient (derivative of cross-entropy w.r.t. logits after softmax)
  For i = 0 To NumOutputs - 1
    If i = targetCharIndex
      outputGrad(i) = 1.0 - *rnn\outputs(i) ; target is 1.0 for the correct class
    Else
      outputGrad(i) = 0.0 - *rnn\outputs(i) ; target is 0.0 for wrong classes
    EndIf
  Next
  ;Update output weights and bias: w + lr * gradient * activation
  For i = 0 To NumOutputs - 1
    For j = 0 To NumHidden - 1
      *rnn\Who(i,j) + *rnn\LearningRate * outputGrad(i) * *rnn\hidden(j)
    Next
    *rnn\bo(i) + *rnn\LearningRate * outputGrad(i)
  Next
  ;Backpropagate into the hidden layer through Who and the tanh nonlinearity
  For i = 0 To NumHidden - 1
    hiddenError(i) = 0.0
    For j = 0 To NumOutputs - 1
      hiddenError(i) + outputGrad(j) * *rnn\Who(j,i)
    Next
    hiddenError(i) * TanhDerivative(*rnn\hidden(i))
  Next
  ;Update hidden weights with gradient clipping to prevent explosions
  Protected clipValue.d = 5.0 ; FIX: was .f — gradient clipping threshold
  For i = 0 To NumHidden - 1
    If hiddenError(i) > clipValue
      hiddenError(i) = clipValue
      ; NOTE(review): this halving PERMANENTLY decays the learning rate for the
      ; rest of the run — looks deliberate (damping after instability); kept as-is
      *rnn\LearningRate / 2
      Debug "clip"
    ElseIf hiddenError(i) < -clipValue
      hiddenError(i) = -clipValue
      *rnn\LearningRate / 2
      Debug "clip"
    EndIf
    For j = 0 To NumInputs - 1
      *rnn\Wih(i,j) + *rnn\LearningRate * hiddenError(i) * *rnn\inputs(j)
    Next
    For j = 0 To NumHidden - 1
      *rnn\Whh(i,j) + *rnn\LearningRate * hiddenError(i) * *rnn\hiddenPrev(j)
    Next
    *rnn\bh(i) + *rnn\LearningRate * hiddenError(i)
  Next
EndProcedure
; Train on a text sequence: for each character, predict the next one and update.
; Hidden state is reset once per epoch and carried across the whole sequence.
Procedure RNN_TrainOnText(*rnn.RNN, text.s, epochs.i)
  Protected i, epoch
  Protected length = Len(text)
  Protected steps = length - 1 ; number of (char -> nextChar) training pairs
  Protected char.s, nextChar.s
  Protected targetIdx
  Protected totalLoss.d        ; FIX: was .f — accumulate loss in double
  Protected epsilon.d = 0.0000001
  Protected prob.d
  Debug "Training on " + Str(length) + " characters for " + Str(epochs) + " epochs..."
  If steps < 1
    ProcedureReturn ; nothing to learn from an empty or 1-character text
  EndIf
  For epoch = 1 To epochs
    totalLoss = 0.0
    ;Reset hidden state at start of each epoch
    For i = 0 To *rnn\NumHidden - 1
      *rnn\hiddenPrev(i) = 0.0
    Next
    ;Train on each character predicting the next
    For i = 1 To steps
      char = Mid(text, i, 1)
      nextChar = Mid(text, i + 1, 1)
      ;Convert input character to one-hot
      RNN_CharToOneHot(*rnn, char)
      ;Forward pass, keeping the hidden state across the sequence
      RNN_Forward(*rnn, 0)
      ;Get target character index
      If FindMapElement(*rnn\charToIndex(), nextChar)
        targetIdx = *rnn\charToIndex()
        ;Cross-entropy loss: -log(p_correct), clamped away from log(0)
        prob = *rnn\outputs(targetIdx)
        If prob < epsilon
          prob = epsilon
        EndIf
        totalLoss - Log(prob)
        RNN_Train(*rnn, targetIdx)
      EndIf
    Next
    If epoch % 10 = 0 Or epoch = 1
      ;FIX: average over the number of prediction steps (length-1), not length
      Debug "Epoch " + Str(epoch) + " - Average Loss: " + StrF(totalLoss / steps, 4)
    EndIf
  Next
EndProcedure
;Greedy text generation: run the seed through the network to warm up the
;hidden state, then repeatedly emit the most probable character and feed it
;back in as the next input. Returns seed + generated characters.
Procedure.s RNN_Predict(*rnn.RNN, seed.s, length.i = 50)
  Protected step, k, bestIdx
  Protected bestProb.f
  Protected ch.s
  Protected result.s = seed

  ;Start from a clean hidden state
  For k = 0 To *rnn\NumHidden - 1
    *rnn\hiddenPrev(k) = 0.0
  Next

  ;Warm up the hidden state on the seed text
  For step = 1 To Len(seed)
    ch = Mid(seed, step, 1)
    RNN_CharToOneHot(*rnn, ch)
    RNN_Forward(*rnn, 0)
  Next

  ;Generate: argmax over the output distribution each step
  For step = 1 To length
    bestProb = -1.0
    bestIdx = 0
    For k = 0 To *rnn\vocabSize - 1
      If *rnn\outputs(k) > bestProb
        bestProb = *rnn\outputs(k)
        bestIdx = k
      EndIf
    Next
    ;Append the winner and use it as the next input
    ch = *rnn\indexToChar(bestIdx)
    result + ch
    RNN_CharToOneHot(*rnn, ch)
    RNN_Forward(*rnn, 0)
  Next

  ProcedureReturn result
EndProcedure
; Fill predictions()/probs() with the topN most probable next characters,
; ordered best-first, based on the current *rnn\outputs distribution.
; Simple selection: repeatedly pick the highest not-yet-taken probability.
Procedure RNN_GetTopPredictions(*rnn.RNN, List predictions.s(), List probs.f(), topN.i = 5)
  Protected pick, k, bestIdx
  Protected bestProb.f
  Protected Dim taken.i(*rnn\vocabSize - 1) ; marks indices already emitted

  ClearList(predictions())
  ClearList(probs())

  For pick = 1 To topN
    bestProb = -1.0
    bestIdx = -1
    ;Scan for the highest probability not yet taken
    For k = 0 To *rnn\vocabSize - 1
      If (Not taken(k) And *rnn\outputs(k) > bestProb)
        bestProb = *rnn\outputs(k)
        bestIdx = k
      EndIf
    Next
    If bestIdx >= 0
      taken(bestIdx) = 1
      AddElement(predictions())
      predictions() = *rnn\indexToChar(bestIdx)
      AddElement(probs())
      probs() = bestProb
    EndIf
  Next
EndProcedure
;--- Demo: build a vocabulary from a tiny code snippet, train, then predict ---
Global Code.s = "For i = 0 To 10" + #CRLF$ +
" Debug i" + #CRLF$ +
"Next" + #CRLF$ +
"For j = 0 To 5" + #CRLF$ +
" Debug j" + #CRLF$ +
"Next" + #CRLF$
; Create the RNN shell first; it is sized only after the vocabulary is known
Global size, *rnn.RNN = RNN_New()
; Build vocabulary from training text
size = RNN_BuildVocabulary(*rnn, Code)
RNN_Init(*rnn, size, 60, 0.021)
; Train
RNN_TrainOnText(*rnn, Code, 500)
; Test predictions
Debug ""
Debug "Predict 59 chars from 'For i =' "
Debug RNN_Predict(*rnn, "For i = ", 59)
Debug ""
Debug "Top 5 predictions after 'For j '"
NewList predictions.s()
NewList probs.f()
; Feed "For j " one character at a time to build up the hidden state
RNN_CharToOneHot(*rnn, "F")
RNN_Forward(*rnn, 0)
RNN_CharToOneHot(*rnn, "o")
RNN_Forward(*rnn, 0)
RNN_CharToOneHot(*rnn, "r")
RNN_Forward(*rnn, 0)
RNN_CharToOneHot(*rnn, " ")
RNN_Forward(*rnn, 0)
RNN_CharToOneHot(*rnn, "j")
RNN_Forward(*rnn, 0)
RNN_CharToOneHot(*rnn, " ")
RNN_Forward(*rnn, 0)
RNN_GetTopPredictions(*rnn, predictions(), probs(), 5)
; BUGFIX: probs() must be walked from its start. After RNN_GetTopPredictions()
; the list's current element is the LAST one added, so the old loop printed the
; last probability against the first prediction and then stalled at the end.
; Reset, then advance probs() in lock-step with predictions().
ResetList(probs())
ForEach predictions()
  NextElement(probs())
  Debug "'" + predictions() + "' : " + StrF(probs(), 4)
Next
