# Pastebin UhpJTpUF diff --git a/src/mlpack/methods/ann/gan_impl.hpp b/src/mlpack/methods/ann/gan_impl.hpp index 0b787443c..97f945979 100644 --- a/src/mlpack/methods/ann/gan_impl.hpp +++ b/src/mlpack/methods/ann/gan_impl.hpp @@ -136,7 +136,7 @@ double GAN::Evaluate( std::move(boost::apply_visitor( outputParameterVisitor, discriminator.network.back())), std::move(currentTarget)); - noise.imbue( [&]() { return noiseFunction(randGen);} ); + noise.imbue( [&]() { return noiseFunction();} ); generator.Forward(std::move(noise)); arma::mat temp = boost::apply_visitor( outputParameterVisitor, generator.network.back()); @@ -193,7 +193,7 @@ Gradient(const arma::mat& /*parameters*/, const size_t i, arma::mat& gradient) // get the gradients of the discriminator discriminator.Gradient(discriminator.parameter, i, gradientDiscriminator); - noise.imbue( [&]() { return noiseFunction(randGen);} ); + noise.imbue( [&]() { return noiseFunction();} ); generator.Forward(std::move(noise)); discriminator.predictors.col(numFunctions) = boost::apply_visitor( outputParameterVisitor, generator.network.back()); @@ -204,7 +204,7 @@ Gradient(const arma::mat& /*parameters*/, const size_t i, arma::mat& gradient) gradientDiscriminator += noiseGradientDiscriminator; - if (currentBatch % generatorUpdateStep == 0 && preTrainSize != 0) + if (currentBatch % generatorUpdateStep == 0 && preTrainSize == 0) { // Minimize log(1 - D(G(noise))) // pass the error from discriminator to generator @@ -220,12 +220,13 @@ Gradient(const arma::mat& /*parameters*/, const size_t i, arma::mat& gradient) gradientGenerator = -gradientGenerator; gradientGenerator *= multiplier; - +/* if (counter % batchSize == 0) { Log::Info << "gradientDiscriminator = " << std::max(std::fabs(gradientDiscriminator.min()), std::fabs(gradientDiscriminator.max())) << std::endl; Log::Info << "gradientGenerator = " << std::max(std::fabs(gradientGenerator.min()), std::fabs(gradientGenerator.max())) << std::endl; } +*/ } counter++; if (counter >= numFunctions)