Форум для обсуждения курса

2-ая лабораторная работа

2-ая лабораторная работа

от Арсений Елисеев -
Количество ответов: 0

points = [

1, 1;   # 1

10, 2;

1, 4;

8, 5;


-2, 5;  # 2

-8, 4;

-3, 2;

-1, 1;


-2, -1; # 3

-1, -4;

-4, -6;

-5, -2;


1, -2;  # 4

12, -2;

11, -5;

2, -7;

];


ht = size(points)(1);

wd = size(points)(2);


H1 =  ones(wd + 1, ht);


for i = 1:wd

  for j = 1:ht

    H1(i, j) = points(j, i);

  endfor

endfor


s1 = 3;

s2 = 4;

N = ht;


W = ones(s2, s1);

Z2 = W * H1;


function retmatrix = softmax(M)

  

  for j = 1:size(M)(2)

    sum(j) = 0;

    

    for i  = 1:size(M)(1)

      sum(j) += exp(M(i, j));

    endfor

  endfor

  

  for i = 1:size(M)(1)

    for j = 1:size(M)(2)

      retmatrix(i, j) = exp(M(i,j)) / sum(j);

    endfor

  endfor

endfunction


H2 = softmax(Z2);


function retVal = lossFunction(X,Y)

  retVal = 0;

  for i = 1:size(X)(1)

    for j = 1: size(X)(2)

      retVal += (X(i,j) - Y(i,j)) * (X(i,j) - Y(i,j)); 

    endfor

  endfor

endfunction


Y = [

1 0 0 0;

1 0 0 0;

1 0 0 0;

1 0 0 0;

0 1 0 0;

0 1 0 0;

0 1 0 0;

0 1 0 0;

0 0 1 0;

0 0 1 0;

0 0 1 0;

0 0 1 0;

0 0 0 1;

0 0 0 1;

0 0 0 1;

0 0 0 1;

];


function retMatrix = softmaxPrime(M)

  ht = size(M)(1);

  wd = size(M)(2);

  mOnes = ones(ht, wd);

  

  retMatrix = softmax(M).*(mOnes - softmax(M));

endfunction


function H2 = directPropagation(H1, W)

  Z2 = W * H1;

  H2 = softmax(Z2);

endfunction


Y = Y';

loss = lossFunction(H2, Y);

test = softmaxPrime(Z2);

delta = (H2 - Y).*softmaxPrime(Z2);

grad = delta * H1';


W0 = W;

alpha = 0.55;

W1 = W0 - alpha * grad;

newH2 = directPropagation(H1, W1);

loss1 = lossFunction(newH2, Y);



#hold on;

#X = -20:0.1:20;

#Y = -20:0.1:20;


#plot(X, 0, 'k');

#plot(0, Y, 'k');


#for i = 1:size(points)(1)

#  if points(i, 1) > 0 && points(i, 2) > 0

#    plot(points(i, 1), points(i, 2), 'r*');elseif points(i, 1) > 0 && points(i, 2) < 0

#    plot(points(i, 1), points(i, 2), 'b*');

#   elseif points(i, 1) > 0 && points(i, 2) < 0

#    plot(points(i, 1), points(i, 2), 'b*');

#   elseif points(i, 1) < 0 && points(i, 2) < 0

#    plot(points(i, 1), points(i, 2), 'g*');

#   else

#    plot(points(i, 1), points(i, 2), 'm*');

#  endif

#    

#endfor