基于MATLAB的麻雀搜索算法实战
今天给大家分享麻雀搜索算法的代码实战,本次分享主要从算法原理和代码实战展开。需要了解更多算法代码的,可以点击文章左下角的阅读全文,进行获取哦~需要了解智能算法、机器学习、深度学习和信号处理相关理论的可以后台私信哦,下一期分享的内容就是你想了解的内容~
一、算法原理
麻雀搜索算法(Sparrow Search Algorithm, SSA)。该算法由东华大学的Xue和Shen于2020年提出,该算法通过模拟麻雀种群觅食的行为,在种群中设定发现者、加入者、侦察者 3种身份的个体,通过叠加侦查预警机制,迭代更新群体觅食位置,以获得全局最优的觅食资源,从而获得参数的最优解。具有较高的全局寻优和求解能力。麻雀搜索算法的伪代码如下:
麻雀搜索算法的算法流程图如下图所示:
二、代码实战
以SSA优化DBN手写数字集分类为例
%%主函数
%
clc;
clear;
close all;
addpath('D:\学习资料\DBN11\SSA2')%%记得更换路径
addpath('D:\学习资料\DBN11\DBN')
load mnist_uint8;
train_x = double(train_x) / 255;
test_x = double(test_x) / 255;
train_y = double(train_y);
test_y = double(test_y);
% [train_x,train_y,test_x,test_y,RealTrainYlabel,RealTestLabel] = TrainTestSplit(0.8);
%
[Params,CostFunction] = ParameterDefinition(train_x,train_y);
% %% SSA
[particle, GlobalBest,SD,GlobalWorst,Predator,Joiner] = Initialization(Params,CostFunction,'SSA');
[particle,GlobalBest] = SSA(particle,GlobalBest,GlobalWorst,SD,Predator,Joiner,Params,CostFunction);
%%SSA函数
function [particle,GlobalBest] = SSA(particle,GlobalBest,GlobalWorst,SD,Predator,Joiner,Params,CostFunction)
MaxIter = Params.MaxIter;
nPop = Params.nPop;
VarMin = Params.VarMin;
VarMax = Params.VarMax;
VarSize = Params.VarSize;
nVar = 3;
BestCost = zeros(1,MaxIter);
%% Main loop
for i = 1:MaxIter
for j = 1:length(Predator)
alarm = randn ;
ST = randn;
if alarm < ST
Predator(j).Position = Predator(j).Position .* exp( -j /MaxIter);
else
Predator(j).Position = Predator(j).Position + randn * ones(VarSize);
end
Predator(j).Position = max(VarMin,Predator(j).Position);
Predator(j).Position = min(VarMax,Predator(j).Position);
Predator(j).Cost = CostFunction(Predator(j).Position);
end
[~,idx] = min([Predator.Cost]);
BestPredator = Predator(idx);
% 加入者更新
for j = 1: nPop - length(Predator)
if j + length(Predator)> nPop/2
Joiner(j).Position = randn .* exp( (GlobalWorst.Position - Joiner(j).Position) / j^2);
else
A = randi([0,1],1,nVar);
A(~A) = -1;
Ahat = A' / (A * A');
Joiner(j).Position = BestPredator.Position + abs(Joiner(j).Position - BestPredator.Position) * Ahat * ones(VarSize);
end
Joiner(j).Position = max(VarMin,Joiner(j).Position);
Joiner(j).Position = min(VarMax,Joiner(j).Position);
Joiner(j).Cost = CostFunction(Joiner(j).Position);
end
% 警觉者更新
for j = 1:length(SD)
if SD(j).Cost > GlobalBest.Cost
SD(j).Position = GlobalBest.Position + randn * abs( SD(j).Position - GlobalBest.Position);
elseif SD(j).Cost == GlobalBest.Cost
SD(j).Position = SD(j).Position + (rand*2-1) * (abs( SD(j).Position - GlobalWorst.Position)./ ((SD(j).Cost - GlobalWorst.Cost) + 0.001));
end
SD(j).Position = max(VarMin,SD(j).Position);
SD(j).Position = min(VarMax,SD(j).Position);
SD(j).Cost = CostFunction(SD(j).Position);
end
% 更新
particle = [Predator;Joiner;SD];
for m = 1:length(particle)
if GlobalBest.Cost > particle(m).Cost
GlobalBest = particle(m);
end
if GlobalWorst.Cost < particle(m).Cost
GlobalWorst = particle(m);
end
end
BestCost(i) = GlobalBest.Cost;
disp(['当前迭代',num2str(i), '最优值为: ', num2str(GlobalBest.Cost)])
end
%% Results
figure;
%plot(BestCost,'LineWidth',2);
semilogy(BestCost,'LineWidth',2);
xlabel('Iteration');
ylabel('Best Cost');
grid on;
end
实验结果:
完整代码后台回复获取哦
部分知识来源于网络,如有侵权请联系作者删除~
今天的分享就到这里了,后续想了解智能算法、机器学习、深度学习和信号处理相关理论的可以后台私信哦~希望大家多多转发点赞加收藏,你们的支持就是我源源不断的创作动力!
作 者 | 华 夏
编 辑 | 华 夏
校 对 | 华 夏
文章来源:matlab学习之家