CN116129197A - Fish classification method, system, equipment and medium based on reinforcement learning - Google Patents
- Publication number: CN116129197A
- Application number: CN202310347212.6A
- Authority: CN (China)
- Prior art keywords: pruning, network, fish, block, model
- Legal status: Pending (the legal status is an assumption and is not a legal conclusion; Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed)
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/082—Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02A—TECHNOLOGIES FOR ADAPTATION TO CLIMATE CHANGE
- Y02A40/00—Adaptation technologies in agriculture, forestry, livestock or agroalimentary production
- Y02A40/80—Adaptation technologies in agriculture, forestry, livestock or agroalimentary production in fisheries management
- Y02A40/81—Aquaculture, e.g. of fish
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Evolutionary Computation (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- Computing Systems (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Multimedia (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Medical Informatics (AREA)
- Databases & Information Systems (AREA)
- Computational Linguistics (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Data Mining & Analysis (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Image Analysis (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
The invention discloses a fish classification method, system, equipment and medium based on reinforcement learning, and relates to the field of fish classification. A baseline network model is trained with a sample data set, and the trained baseline network model is pruned to obtain a fish classification model; fish images to be classified are then classified with the fish classification model to obtain the fish species, which improves classification accuracy and efficiency.
Description
Technical Field
The invention relates to the field of fish classification, in particular to a fish classification method, system, equipment and medium based on reinforcement learning.
Background
Effective classification of fish data is an important means of studying aquatic ecosystems. In recent years, deep neural networks (DNNs) have been widely applied to fish data classification and have achieved significant results. However, fish data are difficult to acquire, their class distributions are imbalanced, and DNNs have large numbers of parameters and a high computational cost, so traditional deep network models face great challenges in classifying fish data accurately. A viable approach to this problem is to compress the network model without degrading its accuracy. Network pruning is a common model compression technique and offers significant advantages in making complex network models more efficient.
Network pruning removes redundant parameters and structures from a network to obtain a sparser network, and can be divided into unstructured pruning and structured pruning. Unstructured pruning achieves higher sparsity of the weight matrices by removing the unimportant weights of each layer; for example, Song Han et al. proposed a threshold-based pruning method that treats weights whose absolute value is below a threshold as unimportant and deletes them. Unstructured pruning, however, requires dedicated software and hardware support and introduces additional computational cost. By contrast, structured pruning reduces network parameters and computational cost by removing redundant layers, convolution kernels and channels, and is applicable to a wider range of scenarios.
Compared with inheriting the important weight parameters of the baseline network, the structure of the pruned network is the key factor that determines the performance of the pruned model. Network pruning can therefore be regarded as a network architecture search problem: all networks that satisfy the search conditions are called sub-networks or candidate networks, together they form the network search space, and the goal of the search is to find the optimal sub-network in this space.
At present, some network pruning methods prune the network model with manually specified pruning rates, but in practice manually specified pruning rates make pruning inefficient and tend to converge to a local optimum. In addition, most network pruning methods prune the network layer by layer and cannot fully account for inter-layer dependencies: searching for a sparse structure one layer at a time makes poor use of global information about the network structure, and such layer-wise strategies often yield suboptimal compression. Furthermore, network pruning methods depend heavily on labels: most require labeled data during pruning, which limits their use when data labels are unavailable. Finally, when pruning is treated as a neural architecture search, conventional search methods face a very large search space, which makes it difficult to find the optimal sub-network structure.
In summary, when current network pruning methods are used to prune deep network models for fish data classification, both classification accuracy and efficiency remain low.
Disclosure of Invention
The invention aims to provide a fish classification method, a system, equipment and a medium based on reinforcement learning so as to improve the accuracy of classifying fish.
In order to achieve the above object, the present invention provides the following solutions:
a fish classification method based on reinforcement learning, comprising:
acquiring an image of fish to be classified;
inputting the fish images to be classified into a fish classification model to obtain classification results; the classification result is the type of fish;
the fish classification model is obtained by training a baseline network model with a sample data set and pruning the trained baseline network model; the sample data set comprises a training set, a verification set and a test set, each of which comprises a plurality of fish images and the fish species labels corresponding to the fish images.
Optionally, the construction process of the fish classification model specifically comprises the following steps:
training the baseline network model with the training set to obtain a trained baseline network model;
initializing a pruning network model from the trained baseline network model to obtain an initial pruning network model;
dividing the trained baseline network model and the initial pruning network model by layer into a plurality of baseline block networks and a plurality of pruning block networks, respectively;
inputting the training set into the pruning block networks and the baseline block networks, and determining a metric score for each pruning block network;
determining the pruning rate of each pruning block network with a reinforcement learning algorithm according to the metric scores;
pruning each pruning block network according to its pruning rate to obtain a plurality of pruned pruning block networks;
and constructing the fish classification model from the pruned pruning block networks based on the verification set and the test set.
Optionally, the constructing a fish classification model according to the pruned pruning network based on the verification set and the test set specifically includes:
inputting the verification set into the pruned pruning block network and the baseline block network respectively to obtain a first output result and a second output result;
calculating a first mean square error of the first output result and the second output result, and calculating a first pruning efficiency metric value of the pruned pruning block network;
performing a trade-off calculation on the first mean square error and the first pruning efficiency metric value to obtain a trade-off value;
selecting a preset number of pruned pruning block networks in descending order of trade-off value, and constructing an initial fish classification model;
and adjusting parameters of the initial fish classification model by using the test set to obtain a fish classification model.
Optionally, training the baseline network model with the training set to obtain a trained baseline network model specifically includes:
performing data augmentation on the fish images using random shuffling, zero padding and random sampling to obtain a processed training set;
and training the baseline network model with the processed training set to obtain a trained baseline network model.
Optionally, the training set is input into the pruning block network and the baseline block network, and the measurement score of each pruning block network is determined, which specifically includes:
respectively inputting the training set into a first-stage pruning block network and a first-stage baseline block network to obtain pruning block network output results and baseline block network output results;
calculating the mean square error of the pruning block network output result and the baseline block network output result;
calculating the accuracy metric value of the current pruning block network according to the mean square error;
calculating the pruning efficiency metric value of the current pruning block network from its FLOPs compression rate, wherein $\mathrm{FLOPs}(S_i)$ denotes the FLOPs of the $i$-th pruning block network and $\mathrm{FLOPs}(B_i)$ denotes the FLOPs of the $i$-th baseline block network;
determining the measurement score of the current pruning block network according to the accuracy measurement value and the pruning efficiency measurement value;
and inputting the baseline block network output result to a next-stage pruning block network and a next-stage baseline block network to obtain a pruning block network output result and a baseline block network output result, and returning to the step of calculating the mean square error of the pruning block network output result and the baseline block network output result to obtain the measurement score of each pruning block network.
Optionally, pruning each pruning block network according to the pruning rate to obtain a plurality of pruned pruning block networks specifically includes:
calculating the number of convolution kernels to be pruned in the current layer according to the pruning rate and the number of convolution kernels in each layer of the block network;
calculating an importance score for each convolution kernel in each layer of the block network;
pruning the convolution kernels of each layer in ascending order of importance score, up to the number of convolution kernels to be pruned in the current layer, to obtain a pruned block network.
Optionally, calculating the number of convolution kernels to be pruned in the current layer according to the pruning rate and the number of convolution kernels in each layer specifically includes:
calculating the number of convolution kernels to be pruned in the current layer as $v = o \times u$, where $v$ is the number of convolution kernels to be pruned in the current layer, $o$ is the pruning rate of the current layer, and $u$ is the number of convolution kernels in the current layer;
when $v = u$, the number of convolution kernels to be pruned in the current layer is $u - 1$.
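The kernel-count rule above can be sketched in Python as follows. This is a minimal illustration rather than the patent's code: the helper name `kernels_to_prune` is hypothetical, and rounding $o \times u$ down follows the pruning flow given later in the detailed description.

```python
import math


def kernels_to_prune(pruning_rate: float, num_kernels: int) -> int:
    """Return v, the number of convolution kernels to remove from one layer.

    Hypothetical helper: v = floor(o * u), capped at u - 1 so that at
    least one kernel always survives in the layer.
    """
    if num_kernels < 1:
        raise ValueError("a layer must have at least one convolution kernel")
    if not 0.0 <= pruning_rate <= 1.0:
        raise ValueError("pruning rate must lie in [0, 1]")
    v = math.floor(pruning_rate * num_kernels)
    # When v == u the whole layer would vanish; keep one kernel instead.
    return min(v, num_kernels - 1)


print(kernels_to_prune(0.3, 16))  # 0.3 * 16 = 4.8, rounded down to 4
print(kernels_to_prune(1.0, 16))  # v == u, so only u - 1 = 15 are pruned
```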
A reinforcement learning based fish classification system comprising:
the data acquisition module is used for acquiring the images of the fishes to be classified;
the classification module is used for inputting the fish images to be classified into a fish classification model to obtain classification results; the classification result is the type of fish;
the fish classification model is obtained by training a baseline network model with a sample data set and pruning the trained baseline network model; the sample data set comprises a training set, a verification set and a test set, each of which comprises a plurality of fish images and the fish species labels corresponding to the fish images.
An electronic device, comprising a memory and a processor, wherein the memory is used to store a computer program, and the processor runs the computer program to cause the electronic device to perform the reinforcement learning-based fish classification method described above.
A computer readable storage medium storing a computer program which when executed by a processor implements the reinforcement learning-based fish classification method described above.
According to the specific embodiment provided by the invention, the invention discloses the following technical effects:
According to the reinforcement learning-based fish classification method of the invention, a baseline network model is trained with the sample data set, the trained baseline network model is then pruned to obtain the fish classification model, and the fish images to be classified are classified with the fish classification model to obtain the fish species, which improves classification accuracy and efficiency.
Drawings
In order to more clearly illustrate the embodiments of the present invention or the technical solutions of the prior art, the drawings that are needed in the embodiments will be briefly described below, it being obvious that the drawings in the following description are only some embodiments of the present invention, and that other drawings may be obtained according to these drawings without inventive effort for a person skilled in the art.
FIG. 1 is a flow chart of a fish classification method based on reinforcement learning provided by the invention;
FIG. 2 is a flow chart of a fish classification model construction process of the invention;
FIG. 3 is a framework diagram of the block-network supervised pruning method based on a reinforcement learning algorithm according to the present invention;
FIG. 4 is a flowchart of a reinforcement learning algorithm according to the present invention;
FIG. 5 is a diagram of a network pruning algorithm framework of the present invention;
FIG. 6 is a graph of accuracy metric values for a ResNet-20 network of the present invention;
FIG. 7 is a comparison of the ResNet-20 network before and after pruning according to the present invention;
FIG. 8 is a comparison of the ResNet-56 network before and after pruning according to the present invention;
FIG. 9 is a block diagram of the fish classification system based on reinforcement learning.
Detailed Description
The following description of the embodiments of the present invention will be made clearly and completely with reference to the accompanying drawings, in which it is apparent that the embodiments described are only some embodiments of the present invention, but not all embodiments. All other embodiments, which can be made by those skilled in the art based on the embodiments of the invention without making any inventive effort, are intended to be within the scope of the invention.
The invention aims to provide a fish classification method, a system, equipment and a medium based on reinforcement learning so as to improve the accuracy of classifying fish.
In order that the above-recited objects, features and advantages of the present invention will become more readily apparent, a more particular description of the invention will be rendered by reference to the appended drawings and appended detailed description.
As shown in fig. 1, the fish classification method based on reinforcement learning of the present invention comprises:
Step 101: obtaining an image of the fish to be classified.
Step 102: inputting the fish images to be classified into a fish classification model to obtain classification results; the classification result is the type of fish.
The fish classification model is obtained by training a baseline network model with a sample data set and pruning the trained baseline network model; the sample data set comprises a training set, a verification set and a test set, each of which comprises a plurality of fish images and the fish species labels corresponding to the fish images.
Further, as shown in fig. 2, the construction process of the fish classification model specifically includes:
S1: Training the baseline network model with the training set to obtain a trained baseline network model.
Further, S1 specifically includes:
Performing data augmentation on the fish images using random shuffling, zero padding and random sampling to obtain a processed training set.
Training the baseline network model with the processed training set to obtain a trained baseline network model.
In practice, the fish image data are first preprocessed, and the baseline network model is trained with the processed fish image data to improve its convergence and generalization. First, for each class with fewer than 300 images, the data set is expanded with 5 augmentation methods (horizontal flip, vertical flip, and 90°, 180° and 270° rotation), and all images are then uniformly resized to 224 × 224 pixels. The expanded sample data set is then randomly split into a training set, a verification set and a test set in the ratio 8:1:1. Finally, the baseline network model is trained with the processed training set.
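The class expansion and 8:1:1 split described above can be sketched with NumPy as follows. This is an illustrative sketch, not the patent's pipeline: the function names are hypothetical, images are assumed to be NumPy arrays of shape H×W or H×W×C, and the final resize to 224 × 224 is omitted because it needs an image library.

```python
import numpy as np


def augment_minority_classes(images, labels, min_count=300):
    """Add 5 flipped/rotated copies of every image whose class has < min_count images."""
    counts = {}
    for lab in labels:
        counts[lab] = counts.get(lab, 0) + 1
    out_images, out_labels = list(images), list(labels)
    for img, lab in zip(images, labels):
        if counts[lab] < min_count:
            variants = [
                np.fliplr(img),      # horizontal flip
                np.flipud(img),      # vertical flip
                np.rot90(img, k=1),  # 90 degree rotation
                np.rot90(img, k=2),  # 180 degree rotation
                np.rot90(img, k=3),  # 270 degree rotation
            ]
            out_images.extend(variants)
            out_labels.extend([lab] * len(variants))
    return out_images, out_labels


def split_8_1_1(num_samples, seed=0):
    """Shuffle sample indices and split them 8:1:1 into train / verification / test."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(num_samples)
    n_train = int(0.8 * num_samples)
    n_val = int(0.1 * num_samples)
    return idx[:n_train], idx[n_train:n_train + n_val], idx[n_train + n_val:]


# Toy usage: class "b" has fewer than min_count images, so it gets 5 extra copies.
imgs = [np.arange(16).reshape(4, 4) for _ in range(3)]
aug_imgs, aug_labs = augment_minority_classes(imgs, ["a", "a", "b"], min_count=2)
train_idx, val_idx, test_idx = split_8_1_1(len(aug_imgs))
```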
S2: Initializing a pruning network model from the trained baseline network model to obtain an initial pruning network model. In practice, because the pruning network model has a large search space, it is randomly initialized within a given range of floating-point operation (FLOPs) compression rates, which reduces the space that must be searched for the optimal network.
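One possible reading of "randomly initialized within a range of FLOPs compression rates" is rejection sampling: draw random per-layer pruning rates and accept them only if the resulting FLOPs compression rate falls inside the target range. The sketch below assumes a block of 3×3 convolutions at a fixed spatial size, prunes the output channels of every layer except the last (as done for ResNet blocks in the experiments), and uses hypothetical names and illustrative range bounds.

```python
import math
import random


def block_flops(channels, kernel=3, spatial=32 * 32):
    """Multiply-accumulates of a conv chain mapping channels[l] -> channels[l + 1].

    The spatial size is a constant factor, so it does not affect the ratio below.
    """
    return sum(
        channels[l] * channels[l + 1] * kernel * kernel * spatial
        for l in range(len(channels) - 1)
    )


def prune_channels(channels, rates):
    """Apply one pruning rate per inner layer; the block's input/output widths are kept."""
    pruned = [channels[0]]
    for width, rate in zip(channels[1:-1], rates):
        removed = min(math.floor(rate * width), width - 1)  # keep >= 1 kernel
        pruned.append(width - removed)
    pruned.append(channels[-1])
    return pruned


def random_initial_rates(channels, low, high, seed=0, max_tries=10_000):
    """Rejection-sample per-layer rates whose FLOPs compression rate lies in [low, high]."""
    rng = random.Random(seed)
    base = block_flops(channels)
    for _ in range(max_tries):
        rates = [rng.uniform(0.0, 1.0) for _ in channels[1:-1]]
        compression = 1.0 - block_flops(prune_channels(channels, rates)) / base
        if low <= compression <= high:
            return rates, compression
    raise RuntimeError("no sample fell inside the requested compression range")


rates, compression = random_initial_rates([16, 16, 16], low=0.2, high=0.5)
```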
S3: Dividing the trained baseline network model and the initial pruning network model by layer into a plurality of baseline block networks and a plurality of pruning block networks. In practice, both models are split into block networks at the same layer boundaries. As shown in FIG. 3, to improve pruning efficiency and drawing on the idea of knowledge distillation, the trained baseline network model $B$ and the initial pruning network model $S$ are divided by layer into several block networks, and each pruning block network learns the knowledge of its corresponding baseline block network: the input to the $i$-th pruning block network $S_i$ and to the $i$-th baseline block network $B_i$ is the output of the $(i-1)$-th baseline block network $B_{i-1}$.
S4: the training set is input into the pruning block network and the baseline block network, and a metric score of each pruning block network is determined.
Further, the step S4 specifically includes:
Inputting the training set into the first-stage pruning block network and the first-stage baseline block network, respectively, to obtain a pruning block network output and a baseline block network output.
Calculating the mean square error (MSE) between the pruning block network output and the baseline block network output.
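In practice, the MSE between the pruning block network and the baseline block network is calculated as

$$\mathrm{MSE}_i = \frac{1}{n}\left\lVert f(X_i, W) - g(X_i, W') \right\rVert_2^2,$$

where $f(X_i, W)$ and $g(X_i, W')$ denote the outputs of the $i$-th baseline block network $B_i$ and the $i$-th pruning block network $S_i$, respectively, and $n$ is the number of elements in a block output.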
Calculating the accuracy metric value of the current pruning block network from the mean square error. In practice, different network structures are evaluated with an accuracy metric $R_a$ that is defined from the MSE loss and serves as a proxy for accuracy.
Calculating the pruning efficiency metric value of the current pruning block network from $\mathrm{FLOPs}(S_i)$ and $\mathrm{FLOPs}(B_i)$.
Here $\mathrm{FLOPs}(S_i)$ denotes the FLOPs of the $i$-th pruning block network and $\mathrm{FLOPs}(B_i)$ denotes the FLOPs of the $i$-th baseline block network. In practice, to further distinguish block networks that perform similarly but differ in computational efficiency, the invention defines the model efficiency metric by the FLOPs compression rate of the pruning block network.
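The FLOPs side of the efficiency metric can be made concrete with a short sketch. It assumes, since the exact formula does not survive in this text, that the efficiency metric is the FLOPs compression rate $1 - \mathrm{FLOPs}(S_i)/\mathrm{FLOPs}(B_i)$, and it counts only convolution multiply-accumulates; the `ConvLayer` record and helper names are hypothetical.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ConvLayer:
    """Shape of one k x k convolution (hypothetical record type)."""
    c_in: int
    c_out: int
    kernel: int
    h_out: int
    w_out: int

    def flops(self) -> int:
        # Multiply-accumulate count; bias and activation costs are ignored.
        return self.c_in * self.c_out * self.kernel ** 2 * self.h_out * self.w_out


def block_flops(layers) -> int:
    return sum(layer.flops() for layer in layers)


def flops_compression_rate(pruned_block, baseline_block) -> float:
    """ASSUMED efficiency metric: fraction of the baseline block's FLOPs removed."""
    return 1.0 - block_flops(pruned_block) / block_flops(baseline_block)


# A ResNet basic block (two 3x3 convs, 16 channels, 32x32 maps) whose first
# conv keeps 8 of its 16 kernels; the block's output width is unchanged.
baseline = [ConvLayer(16, 16, 3, 32, 32), ConvLayer(16, 16, 3, 32, 32)]
pruned = [ConvLayer(16, 8, 3, 32, 32), ConvLayer(8, 16, 3, 32, 32)]
rate = flops_compression_rate(pruned, baseline)  # half of the FLOPs are removed
```

Note that removing output kernels of one convolution also shrinks the input channels of the next one, which is why both layers of the pruned block change.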
Determining the metric score of the current pruning block network from the accuracy metric value and the pruning efficiency metric value. In practice, model performance (the accuracy metric) and model efficiency are combined into a single score $R$ that reflects the quality of the pruning network model.
Here $\alpha$ is a weight that balances network model performance against efficiency; higher values of $\alpha$ favor removing more FLOPs. For each block network in the pruning network, the goal is to find the block network with the highest metric score $R$.
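The effect of $\alpha$ can be illustrated with a sketch. The concrete forms here are HYPOTHETICAL, because the patent's formulas for $R_a$ and $R$ are not reproduced in this text: it assumes $R_a = e^{-\mathrm{MSE}}$ and $R = R_a + \alpha \cdot (\text{FLOPs compression rate})$, and the candidate values are made up for illustration.

```python
import math


def accuracy_metric(mse: float) -> float:
    # HYPOTHETICAL R_a: any score that falls as the MSE rises fits the text.
    return math.exp(-mse)


def metric_score(mse: float, compression_rate: float, alpha: float) -> float:
    # HYPOTHETICAL combination R = R_a + alpha * (FLOPs compression rate).
    return accuracy_metric(mse) + alpha * compression_rate


def best_candidate(candidates, alpha):
    """Return the candidate pruning block network with the highest score R."""
    return max(candidates, key=lambda c: metric_score(c["mse"], c["compression"], alpha))


candidates = [
    {"name": "light", "mse": 0.01, "compression": 0.2},
    {"name": "heavy", "mse": 0.05, "compression": 0.6},
]
print(best_candidate(candidates, alpha=0.0)["name"])  # accuracy only
print(best_candidate(candidates, alpha=1.0)["name"])  # efficiency weighted in
```

With $\alpha = 0$ the lightly pruned candidate wins on accuracy alone; with $\alpha = 1$ the more heavily compressed candidate wins, matching the statement that a higher $\alpha$ favors removing more FLOPs.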
Inputting the baseline block network output into the next-stage pruning block network and the next-stage baseline block network to obtain their outputs, and returning to the step of calculating the mean square error between the pruning block network output and the baseline block network output, until the metric score of every pruning block network has been obtained.
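The block-wise loop of step S4 can be sketched as follows: each stage compares $S_i$ with $B_i$ on the same input, and the next stage receives the baseline output rather than the pruned one. Toy two-layer NumPy blocks stand in for convolutional block networks (an assumption for illustration), pruning is mimicked by dropping hidden units while keeping each block's output width, and all names are hypothetical.

```python
import numpy as np


def make_block(w1, w2):
    """Toy block network: ReLU(x @ w1) @ w2 (hypothetical stand-in for conv layers)."""
    def block(x):
        return np.maximum(x @ w1, 0.0) @ w2
    return block


def blockwise_mse(baseline_blocks, pruning_blocks, x):
    """MSE between S_i and B_i outputs; both receive the output of B_{i-1}."""
    mses = []
    for base_block, prune_block in zip(baseline_blocks, pruning_blocks):
        base_out = base_block(x)
        prune_out = prune_block(x)
        mses.append(float(np.mean((base_out - prune_out) ** 2)))
        x = base_out  # the next stage is fed the baseline output
    return mses


rng = np.random.default_rng(0)
weights = [(rng.normal(size=(8, 16)), rng.normal(size=(16, 8))) for _ in range(3)]
baseline_blocks = [make_block(w1, w2) for w1, w2 in weights]
keep = np.arange(8)  # each pruning block keeps 8 of its 16 hidden units
pruning_blocks = [make_block(w1[:, keep], w2[keep, :]) for w1, w2 in weights]
x = rng.normal(size=(4, 8))
print(blockwise_mse(baseline_blocks, pruning_blocks, x))
```

Feeding every stage from the baseline chain keeps the per-block errors independent, so each pruning block is scored only on its own approximation error.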
S5: Determining the pruning rate of each pruning block network with a reinforcement learning algorithm according to the metric scores.
In practice, a reinforcement learning (RL) algorithm is used to search for the optimal network structure of each pruning block network. As shown in FIG. 4, reinforcement learning is a reward-driven optimization method: it formulates the problem as a Markov decision process and iteratively adjusts its learning policy to find the optimal decision at each step. In this invention, the pruning of the pruning network is formulated as a Markov decision process in which the feature representation of the pruning network model is the state, the pruning rate of each layer is the action, and model efficiency and performance form the reward; the process searches for better per-layer pruning rates within each block network.
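The MDP formulation can be sketched as a small environment. This is a hypothetical sketch, not the patent's implementation: the exact state features and reward are not given in this text, so the state below is a simple per-layer feature vector, the reward is a placeholder callable computed after the last layer, and random actions stand in for a trained DDPG actor.

```python
import numpy as np


class LayerPruningEnv:
    """Hypothetical MDP for one block network: one time step per prunable layer.

    state  = (normalised layer index, normalised kernel count, previous action)
    action = pruning rate of the current layer, clipped to [0, 1]
    reward = 0 on intermediate steps, reward_fn(rates) after the last layer
    """

    def __init__(self, kernel_counts, reward_fn):
        self.kernel_counts = list(kernel_counts)
        self.reward_fn = reward_fn
        self.reset()

    def _state(self):
        n = len(self.kernel_counts)
        count = self.kernel_counts[self.t] if self.t < n else 0
        return np.array([self.t / n, count / max(self.kernel_counts), self.prev_action])

    def reset(self):
        self.t, self.rates, self.prev_action = 0, [], 0.0
        return self._state()

    def step(self, action):
        rate = float(np.clip(action, 0.0, 1.0))
        self.rates.append(rate)
        self.prev_action = rate
        self.t += 1
        done = self.t == len(self.kernel_counts)
        reward = self.reward_fn(self.rates) if done else 0.0
        return self._state(), reward, done


# Placeholder reward (hypothetical): prefer an average pruning rate near 0.3.
env = LayerPruningEnv([16, 16, 32], reward_fn=lambda r: -abs(float(np.mean(r)) - 0.3))
rng = np.random.default_rng(0)
state, done = env.reset(), False
while not done:
    # A trained DDPG actor would map `state` to an action; random actions stand in here.
    state, reward, done = env.step(rng.uniform(0.0, 1.0))
```

In the method described above, the reward would instead come from the metric score $R$ of the pruned block network.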
S6: pruning is carried out on each pruning block network according to the pruning rate, and a plurality of pruning block networks after pruning are obtained.
S7: Constructing the fish classification model from the pruned pruning block networks based on the verification set and the test set.
In practice, each pruning block network is pruned with the pruning rate obtained in step S5, the pruned pruning block networks are then evaluated, and the network model with the best performance is selected as the final searched network model.
In practical application, the step S7 specifically includes:
Inputting the verification set into the pruned pruning block network and the baseline block network, respectively, to obtain a first output and a second output.
Calculating a first mean square error between the first output and the second output, and calculating a first pruning efficiency metric value of the pruned pruning block network.
Performing a trade-off calculation on the first mean square error and the first pruning efficiency metric value to obtain a trade-off value.
Selecting a preset number of pruned pruning block networks in descending order of trade-off value, and constructing an initial fish classification model.
Adjusting the parameters of the initial fish classification model with the test set to obtain the fish classification model.
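The selection step can be sketched as a top-k pick. The trade-off formula is not given in the text, so the sketch assumes a HYPOTHETICAL weighted difference `alpha * efficiency - mse` (larger is better); the candidate values and helper names are illustrative only.

```python
import heapq


def tradeoff_value(mse: float, efficiency: float, alpha: float = 1.0) -> float:
    # HYPOTHETICAL trade-off: reward efficiency, penalise output error.
    return alpha * efficiency - mse


def select_top_k(candidates, k, alpha=1.0):
    """Keep the k pruned block networks with the largest trade-off values."""
    return heapq.nlargest(
        k, candidates, key=lambda c: tradeoff_value(c["mse"], c["efficiency"], alpha)
    )


candidates = [
    {"name": "a", "mse": 0.02, "efficiency": 0.30},
    {"name": "b", "mse": 0.10, "efficiency": 0.35},
    {"name": "c", "mse": 0.01, "efficiency": 0.25},
]
chosen = [c["name"] for c in select_top_k(candidates, k=2)]
```

`heapq.nlargest` returns the candidates already sorted in descending order of the key, which matches the "from large to small" selection.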
Network pruning can be divided into network-layer pruning and intra-layer convolution kernel pruning; this technique prunes only intra-layer convolution kernels. It uses the L1 norm of the weights as the importance criterion for pruning the network model, and the pruning of each layer's convolution kernels is shown in FIG. 5.
The specific flow is as follows:
(1) Rank the convolution kernels by importance. In each layer, compute the importance score of each convolution kernel (or neuron) and sort the kernels in ascending order of importance score.
(2) Calculate the number of convolution kernels to delete in the current layer. If the pruning rate assigned to the layer is $o$ and the layer has $u$ convolution kernels, the number to delete is $v = o \times u$; if $v$ is not an integer, it is rounded down and only its integer part is kept, i.e. $v = \lfloor o \times u \rfloor$.
(3) Delete the unimportant convolution kernels of the current layer. If $v < u$, the first $v$ convolution kernels in the sorted order (the $v$ least important) are deleted directly. If $v = u$, only $u - 1$ kernels are deleted, i.e. at least one convolution kernel is kept; to preserve connectivity between the preceding and following layers, the kernel with the highest importance score in that layer is retained.
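Steps (1) to (3) can be sketched in NumPy for a single convolution layer. This is a minimal illustration under stated assumptions: weights use the common (out_channels, in_channels, k_h, k_w) layout, the L1 norm of each output kernel is its importance score, and the helper name `l1_prune_layer` is hypothetical.

```python
import numpy as np


def l1_prune_layer(weights: np.ndarray, pruning_rate: float):
    """Remove the v least important kernels of one conv layer by L1 norm.

    weights has shape (u, c_in, k_h, k_w): one entry per output kernel.
    Returns the indices of the kept kernels and the pruned weight tensor.
    """
    u = weights.shape[0]
    v = min(int(np.floor(pruning_rate * u)), u - 1)  # always keep >= 1 kernel
    scores = np.abs(weights).reshape(u, -1).sum(axis=1)  # (1) L1 norm per kernel
    order = np.argsort(scores, kind="stable")  # ascending importance
    keep = np.sort(order[v:])  # (3) drop the first v, keep original layer order
    return keep, weights[keep]


# Kernel i is filled with the value i, so its L1 norm grows with i.
w = np.stack([np.full((2, 3, 3), float(i)) for i in range(4)])
keep, pruned = l1_prune_layer(w, 0.5)
```

When a kernel is removed, the matching input channel of the following layer must also be removed; the sketch handles only the layer being pruned.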
To verify the compression performance of the invention on a fish classification model, experiments were conducted on the ResNet-20 network model using the public Fish4Knowledge dataset. The test platform was Ubuntu 18.06 with an AMD 3090X CPU and a Titan RTX GPU with 24 GB of video memory.
The Fish4Knowledge dataset consists of fish images collected at underwater observation stations at Nanwan, Lanyu and Houbihu between October 1, 2010 and September 30, 2013. It contains 27,370 images of 23 fish species, and the number of images per class varies widely: the single most frequent species accounts for about 44% of the images, and the top 15 species account for 97%. Because class imbalance in the training set easily biases the trained model, the data were augmented: for each class with fewer than 300 images, the data set was expanded with 5 augmentation methods (horizontal flip, vertical flip, and 90°, 180° and 270° rotation), and all images were then uniformly resized to 224 × 224 pixels for the subsequent experiments. The data set was randomly shuffled and split 8:1:1 into a training set, a validation set and a test set, giving 29,575 training images, 3,625 test images and 3,625 validation images.
In practical applications, a ResNet is chosen as the baseline network model. A ResNet is built mainly from residual blocks and residual connections, and each residual block contains several convolutional layers. Within a residual block, the input and output feature maps must have the same size unless the block has a (downsampling) shortcut. To keep the output channels of each block unchanged, the invention compresses only the convolutional layers in each block other than the last one.

The baseline network model is trained with the following parameters:
- 10 epochs
- batch size 32
- initial learning rate 0.001
- Adam optimizer with a momentum of 0.9 and a weight decay of 5×10⁻⁴
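The layer-selection rule and the baseline training settings can be summarised as below. The hyperparameter values are the ones stated above. The text does not list individual layers, so the layer naming and the assumed ResNet-20 layout (nine two-convolution basic blocks) are illustrative assumptions.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass(frozen=True)
class BaselineTrainingConfig:
    # Values stated in the text for training the baseline network.
    epochs: int = 10
    batch_size: int = 32
    learning_rate: float = 1e-3
    optimizer: str = "Adam"
    momentum: float = 0.9
    weight_decay: float = 5e-4


def prunable_layers(blocks: Sequence[Sequence[str]]) -> List[str]:
    """Every conv layer except the last one in each block, so that the
    block's output channels (and hence the residual addition) are unchanged."""
    return [layer for block in blocks for layer in block[:-1]]


# Hypothetical ResNet-20 layout: 3 stages x 3 basic blocks x 2 convolutions.
RESNET20_BLOCKS = [
    [f"stage{s}.block{b}.conv1", f"stage{s}.block{b}.conv2"]
    for s in range(1, 4)
    for b in range(1, 4)
]

if __name__ == "__main__":
    print(len(prunable_layers(RESNET20_BLOCKS)))  # 9
```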
To verify the formula, pruning training was performed on ResNet-20 and ResNet-56 networks using the Fish4Knowledge dataset. ResNet-20 is divided into three block networks, Block1, Block2 and Block3, and a compression experiment was run on each. As shown in Fig. 6, R_a gradually decreases as the FLOPs compression rate of each block network increases.
The reinforcement learning algorithm used in the invention is the deep deterministic policy gradient (DDPG) algorithm, which consists of an Actor network and a Critic network. Each network has 2 hidden layers of 300 neurons. The replay buffer size is 600 and the batch size is 32. The learning rate is 0.001 for the Actor network and 0.002 for the Critic network. The soft-update hyperparameter of the target networks is τ = 0.01, and the number of episodes is 600.
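Below is a minimal sketch of these DDPG settings, together with the target-network soft update and replay buffer they refer to. The soft update θ′ ← τθ + (1 − τ)θ′ is the standard DDPG rule. Parameters are plain Python lists here rather than network weights, and the class and function names are assumptions made for illustration.

```python
import random
from collections import deque
from dataclasses import dataclass
from typing import Any, List, Sequence, Tuple


@dataclass(frozen=True)
class DDPGConfig:
    # Hyperparameters stated in the text.
    hidden_layers: Tuple[int, ...] = (300, 300)  # same for Actor and Critic
    buffer_size: int = 600
    batch_size: int = 32
    actor_lr: float = 1e-3
    critic_lr: float = 2e-3
    tau: float = 0.01
    episodes: int = 600


def soft_update(target: Sequence[float], source: Sequence[float], tau: float) -> List[float]:
    """theta_target <- tau * theta_source + (1 - tau) * theta_target."""
    return [tau * s + (1.0 - tau) * t for t, s in zip(target, source)]


class ReplayBuffer:
    """Fixed-capacity buffer; the oldest transitions are dropped first."""

    def __init__(self, capacity: int) -> None:
        self._items: deque = deque(maxlen=capacity)

    def push(self, transition: Any) -> None:
        self._items.append(transition)

    def sample(self, batch_size: int, rng: random.Random) -> list:
        return rng.sample(list(self._items), batch_size)

    def __len__(self) -> int:
        return len(self._items)
```

A small τ such as 0.01 makes the target networks track the learned networks slowly, which is the usual way DDPG stabilises its bootstrapped targets.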
Trained on the Fish4Knowledge dataset, the ResNet-20 network reaches an accuracy of 98.12%. The method compresses 32.53% of the FLOPs of ResNet-20 while improving its accuracy by 0.52%. Fig. 7 shows the ResNet-20 compression results, i.e. the number of convolution kernels in each layer before and after pruning. The experimental results show that the method can find the redundant structural parameters of the network model and compress them effectively.
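To illustrate what "compressing 32.53% of FLOPs" measures, the sketch below counts the multiply-accumulate operations of a chain of convolution layers. It then reports the relative reduction after kernels are removed. The text does not say how FLOPs were counted, so the counting convention and the layer sizes are assumptions. The convention used here ignores bias, activations, shortcuts and the classifier.

```python
from typing import Sequence, Tuple


def conv_flops(c_in: int, c_out: int, k: int, h_out: int, w_out: int) -> int:
    """Multiply-accumulates of a k x k convolution (bias ignored)."""
    return c_in * c_out * k * k * h_out * w_out


def chain_flops(in_channels: int, out_channels: Sequence[int], k: int,
                out_sizes: Sequence[Tuple[int, int]]) -> int:
    """FLOPs of sequential conv layers; pruning kernels in layer i shrinks
    both its own output and the input channels of layer i + 1."""
    total, c_in = 0, in_channels
    for c_out, (h, w) in zip(out_channels, out_sizes):
        total += conv_flops(c_in, c_out, k, h, w)
        c_in = c_out
    return total


def flops_reduction(baseline: int, pruned: int) -> float:
    return 1.0 - pruned / baseline


if __name__ == "__main__":
    sizes = [(32, 32), (32, 32)]
    base = chain_flops(3, [16, 16], 3, sizes)
    pruned = chain_flops(3, [8, 16], 3, sizes)  # half the kernels of layer 1 removed
    print(base, pruned, flops_reduction(base, pruned))  # 2801664 1400832 0.5
```

Removing kernels from one layer also shrinks the next layer's input channels. That is why halving one layer here halves the FLOPs of both layers.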
To further verify the effectiveness of the method on more complex network models, a ResNet-56 network was trained on the Fish4Knowledge dataset, reaching a test accuracy of 98.12%. Fig. 8 shows the pruning results for each layer of ResNet-56. The method prunes 48.43% of its FLOPs, and the accuracy after pruning is 99.22%, an improvement of 1.1%. The experimental results show that the method also compresses complex network models effectively.
The invention combines reinforcement learning with knowledge distillation to provide a reinforcement-learning-based, block-network supervised pruning algorithm. It has the following advantages:
(1) The invention uses reinforcement learning to learn the pruning rate of each layer of the network model. The pruning rate of each layer can therefore be adjusted dynamically according to the efficiency and performance of the network.
(2) During pruning, the invention does not prune the network layer by layer. Instead, it learns the pruning rates of all layers of the network model.
(3) Drawing on knowledge distillation, the invention supervises the pruning network by minimizing the difference between its output features and those of the baseline network. As a result, no data label information is needed during pruning.
(4) Drawing on the Markov chain Monte Carlo method, the invention divides the baseline network and the pruning network by layer into the same block networks, so that all block networks can be pruned at the same time. This reduces the search space of the network model, compresses the network structure effectively and improves pruning efficiency.
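Dividing both networks at the same layer boundaries can be sketched as follows; the boundary indices and layer names are hypothetical. The i-th baseline block and the i-th pruned block always cover the same layers. Each pair can therefore be evaluated on its own, which is what allows all blocks to be pruned at the same time.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_into_blocks(layers: Sequence[T], boundaries: Sequence[int]) -> List[List[T]]:
    """Split an ordered list of layers at the given (increasing) start indices."""
    blocks, start = [], 0
    for end in list(boundaries) + [len(layers)]:
        blocks.append(list(layers[start:end]))
        start = end
    return blocks


def paired_blocks(baseline_layers: Sequence[T], pruned_layers: Sequence[T],
                  boundaries: Sequence[int]) -> List[Tuple[List[T], List[T]]]:
    """Split both networks identically and pair the corresponding blocks."""
    if len(baseline_layers) != len(pruned_layers):
        raise ValueError("both networks must have the same number of layers")
    return list(zip(split_into_blocks(baseline_layers, boundaries),
                    split_into_blocks(pruned_layers, boundaries)))
```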
Example two
To carry out the method of the above embodiment and achieve the corresponding functions and technical effects, a reinforcement-learning-based fish classification system is provided. As shown in Fig. 9, it includes:
The data acquisition module 901, used to acquire images of the fish to be classified.
The classification module 902, configured to input the image of the fish to be classified into a fish classification model to obtain a classification result. The classification result is the type of fish.
The fish classification model is obtained by training a baseline network model on a sample data set and pruning the trained baseline network model. The sample data set comprises a training set, a verification set and a test set. Each of these sets comprises a plurality of fish images and the fish type labels corresponding to those images.
Example three
The invention also provides an electronic device comprising a memory and a processor. The memory stores a computer program, and the processor runs the computer program so that the electronic device performs the reinforcement-learning-based fish classification method of the first embodiment. The electronic device is a fish classification device.
Example four
The present invention also provides a computer-readable storage medium storing a computer program which, when executed by a processor, implements the reinforcement learning-based fish classification method of the first embodiment.
The embodiments in this specification are described progressively. Each embodiment focuses on its differences from the others, and the identical or similar parts can be cross-referenced. The system disclosed in the embodiment corresponds to the disclosed method, so its description is relatively brief; for related details, refer to the description of the method.
Specific examples have been used here to explain the principles and embodiments of the invention. They are intended only to help in understanding the method of the invention and its core ideas. Those of ordinary skill in the art may, in light of the ideas of the invention, make changes to the specific embodiments and the scope of application. In summary, this description should not be construed as limiting the invention.
Claims (10)
1. A fish classification method based on reinforcement learning, comprising:
acquiring an image of fish to be classified;
inputting the fish images to be classified into a fish classification model to obtain classification results; the classification result is the type of fish;
the fish classification model is obtained by training a baseline network model using a sample data set and pruning the trained baseline network model; the sample data set comprises a training set, a verification set and a test set; the training set, the verification set and the test set each comprise a plurality of fish images and the fish type labels corresponding to the fish images.
2. The reinforcement learning-based fish classification method of claim 1, wherein the process of constructing the fish classification model specifically comprises:
training the baseline network model with the training set to obtain a trained baseline network model;
initializing the pruning network model with the trained baseline network model to obtain an initial pruning network model;
dividing the trained baseline network model and the initial pruning network model by layer into a plurality of baseline block networks and a plurality of pruning block networks;
inputting the training set into the pruning block network and the baseline block network, and determining a metric score of each pruning block network;
determining the pruning rate of each pruning block network by using a reinforcement learning algorithm according to the measurement scores;
pruning each pruning block network according to its pruning rate to obtain a plurality of pruned pruning block networks;
and constructing the fish classification model from the pruned pruning block networks based on the verification set and the test set.
3. The reinforcement learning-based fish classification method according to claim 2, wherein the constructing a fish classification model from the pruned pruning network based on the verification set and the test set specifically comprises:
inputting the verification set into the pruned pruning block network and the baseline block network respectively to obtain a first output result and a second output result;
calculating a first mean square error of the first output result and the second output result, and calculating a first pruning efficiency metric value of the pruned pruning block network;
performing a weighing calculation on the first mean square error and the first pruning efficiency metric value to obtain a weighed value;
selecting a preset number of pruned pruning block networks in descending order of the weighed value to construct an initial fish classification model;
and adjusting parameters of the initial fish classification model by using the test set to obtain a fish classification model.
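The ranking step of claim 3 can be sketched as follows. The claim does not give the weighing formula, so the weighted difference used here (with a weight alpha) is a hypothetical stand-in. Only the descending ordering and the selection of a preset number of candidates follow the claim.

```python
from typing import List, Sequence


def weighed_value(mse: float, efficiency: float, alpha: float = 0.5) -> float:
    # HYPOTHETICAL trade-off: reward FLOPs savings, penalise output mismatch.
    return alpha * efficiency - (1.0 - alpha) * mse


def select_top_candidates(candidates: Sequence[dict], n: int,
                          alpha: float = 0.5) -> List[dict]:
    """Keep the n pruned block networks with the largest weighed value."""
    ranked = sorted(candidates,
                    key=lambda c: weighed_value(c["mse"], c["efficiency"], alpha),
                    reverse=True)
    return ranked[:n]
```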
4. The reinforcement learning-based fish classification method of claim 2, wherein training the baseline network model with the training set to obtain a trained baseline network model specifically comprises:
performing data augmentation on the fish images using random shuffling, zero padding and random sampling to obtain a processed training set;
and training the baseline network model with the processed training set to obtain a trained baseline network model.
5. The reinforcement learning based fish classification method of claim 2, wherein said inputting said training set into said pruning block network and said baseline block network, determining a metric score for each of said pruning block networks, comprises in particular:
respectively inputting the training set into a first-stage pruning block network and a first-stage baseline block network to obtain pruning block network output results and baseline block network output results;
calculating the mean square error of the pruning block network output result and the baseline block network output result;
calculating the accuracy measurement value of the current baseline block network according to the mean square error;
calculating a pruning efficiency metric value of the current pruning block network using a formula in terms of FLOPs(S_i) and FLOPs(B_i), wherein FLOPs(S_i) represents the FLOPs of the i-th pruning block network and FLOPs(B_i) represents the FLOPs of the i-th baseline block network;
determining the measurement score of the current pruning block network according to the accuracy measurement value and the pruning efficiency measurement value;
and inputting the baseline block network output result to a next-stage pruning block network and a next-stage baseline block network to obtain a pruning block network output result and a baseline block network output result, and returning to the step of calculating the mean square error of the pruning block network output result and the baseline block network output result to obtain the measurement score of each pruning block network.
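The block-by-block loop of claim 5 can be sketched as below, with blocks modelled as plain Python functions on vectors. The claim leaves the exact formulas out, so three pieces are hypothetical and marked as such in the code:
- the accuracy metric 1/(1 + MSE);
- the efficiency metric 1 − FLOPs(S_i)/FLOPs(B_i);
- their product as the metric score.

The part that does follow the claim is the data flow: the baseline output of stage i is the input to both stage-(i + 1) blocks.

```python
from typing import Callable, List, Sequence

Vector = List[float]
Block = Callable[[Vector], Vector]


def mse(a: Sequence[float], b: Sequence[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def score_blocks(x: Vector, baseline_blocks: Sequence[Block],
                 pruned_blocks: Sequence[Block],
                 baseline_flops: Sequence[float],
                 pruned_flops: Sequence[float]) -> List[float]:
    scores = []
    for base, pruned, f_b, f_s in zip(baseline_blocks, pruned_blocks,
                                      baseline_flops, pruned_flops):
        out_b, out_s = base(x), pruned(x)
        err = mse(out_s, out_b)               # label-free: compares outputs only
        accuracy = 1.0 / (1.0 + err)          # HYPOTHETICAL accuracy metric
        efficiency = 1.0 - f_s / f_b          # HYPOTHETICAL efficiency metric
        scores.append(accuracy * efficiency)  # HYPOTHETICAL metric score
        x = out_b  # baseline output feeds the next stage of both networks
    return scores
```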
6. The reinforcement learning-based fish classification method according to claim 2, wherein pruning each of the baseline block networks according to the pruning rate to obtain a plurality of pruned baseline block networks specifically comprises:
calculating the number of convolution kernels to be pruned of the current layer according to the pruning rate and the number of convolution kernels of each layer of the baseline block network;
calculating importance scores of convolution kernels of each layer of the baseline block network;
pruning the convolution kernels of each layer of the baseline block network in ascending order of importance score, according to the number of convolution kernels to be deleted in the current layer, to obtain a pruned baseline block network.
7. The reinforcement learning-based fish classification method according to claim 6, wherein the calculating the number of convolution kernels to be pruned in the current layer according to the pruning rate and the number of convolution kernels in each layer of the baseline block network specifically comprises:
calculating the number of convolution kernels to be pruned of the current layer by using a formula v=o×u; v is the number of convolution kernels to be pruned in the current layer; o is pruning rate of the current layer; u is the number of convolution kernels of the current layer;
when v=u, the number of convolution kernels to be pruned at the current layer is u-1.
8. A reinforcement learning-based fish classification system, comprising:
the data acquisition module is used for acquiring the images of the fishes to be classified;
the classification module is used for inputting the fish images to be classified into a fish classification model to obtain classification results; the classification result is the type of fish;
the fish classification model is obtained by training a baseline network model using a sample data set and pruning the trained baseline network model; the sample data set comprises a training set, a verification set and a test set; the training set, the verification set and the test set each comprise a plurality of fish images and the fish type labels corresponding to the fish images.
9. An electronic device, comprising: a memory for storing a computer program, and a processor that runs the computer program to cause the electronic device to perform the reinforcement learning-based fish classification method of any one of claims 1-7.
10. A computer readable storage medium, characterized in that the computer readable storage medium stores a computer program which, when executed by a processor, implements the reinforcement learning based fish classification method of any one of claims 1-7.
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310347212.6A CN116129197A (en) | 2023-04-04 | 2023-04-04 | Fish classification method, system, equipment and medium based on reinforcement learning |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310347212.6A CN116129197A (en) | 2023-04-04 | 2023-04-04 | Fish classification method, system, equipment and medium based on reinforcement learning |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116129197A true CN116129197A (en) | 2023-05-16 |
Family
ID=86303034
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310347212.6A Pending CN116129197A (en) | 2023-04-04 | 2023-04-04 | Fish classification method, system, equipment and medium based on reinforcement learning |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116129197A (en) |
Citations (10)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111340227A (en) * | 2020-05-15 | 2020-06-26 | 支付宝(杭州)信息技术有限公司 | Method and device for compressing business prediction model through reinforcement learning model |
CN111600851A (en) * | 2020-04-27 | 2020-08-28 | 浙江工业大学 | Feature filtering defense method for deep reinforcement learning model |
CN112686382A (en) * | 2020-12-30 | 2021-04-20 | 中山大学 | Convolution model lightweight method and system |
CN112766496A (en) * | 2021-01-28 | 2021-05-07 | 浙江工业大学 | Deep learning model security guarantee compression method and device based on reinforcement learning |
CN113011588A (en) * | 2021-04-21 | 2021-06-22 | 华侨大学 | Pruning method, device, equipment and medium for convolutional neural network |
US20210397965A1 (en) * | 2020-06-22 | 2021-12-23 | Nokia Technologies Oy | Graph Diffusion for Structured Pruning of Neural Networks |
CN114118402A (en) * | 2021-10-12 | 2022-03-01 | 重庆科技学院 | Self-adaptive pruning model compression algorithm based on grouping attention mechanism |
CN115527106A (en) * | 2022-10-21 | 2022-12-27 | 深圳大学 | Imaging identification method and device based on quantitative fish identification neural network model |
CN115600650A (en) * | 2022-11-02 | 2023-01-13 | 华侨大学(Cn) | Automatic convolution neural network quantitative pruning method and equipment based on reinforcement learning and storage medium |
CN115829022A (en) * | 2022-11-16 | 2023-03-21 | 西安交通大学 | CNN network pruning rate automatic search method and system based on reinforcement learning |
2023-04-04: CN application CN202310347212.6A (CN116129197A), active, pending
Non-Patent Citations (3)
Title |
---|
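MANAS GUPTA: "Learning to Prune Deep Neural Networks via Reinforcement Learning", arXiv.org/abs/2007.04756, pages 1 - 11 * |
LIU Huidong (刘会东): "Block-wise Compression Learning Pruning Algorithm" (分块压缩学习剪枝算法), Journal of Chinese Computer Systems (小型微型计算机系统), vol. 44, no. 02, page 3 * |
LIU Huidong (刘会东): "Label-free Network Pruning Based on Reinforcement Learning" (基于强化学习的无标签网络剪枝), Pattern Recognition and Artificial Intelligence (模式识别与人工智能), vol. 34, no. 03, page 2 * |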
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110189334B (en) | Medical image segmentation method of residual error type full convolution neural network based on attention mechanism | |
CN109948029B (en) | Neural network self-adaptive depth Hash image searching method | |
US10339450B2 (en) | System and method for efficient evolution of deep convolutional neural networks using filter-wise recombination and propagated mutations | |
CN108765506B (en) | Layer-by-layer network binarization-based compression method | |
Giacomello et al. | Doom level generation using generative adversarial networks | |
US20190034784A1 (en) | Fixed-point training method for deep neural networks based on dynamic fixed-point conversion scheme | |
CN110136135B (en) | Segmentation method, device, equipment and storage medium | |
CN107103285A (en) | Face depth prediction approach based on convolutional neural networks | |
CN111105017A (en) | Neural network quantization method and device and electronic equipment | |
CN111242268A (en) | Method for searching convolutional neural network | |
CN107240100B (en) | Image segmentation method and system based on genetic algorithm | |
CN114548591A (en) | Time sequence data prediction method and system based on hybrid deep learning model and Stacking | |
CN110033089A (en) | Deep neural network parameter optimization method and system based on Distributed fusion algorithm | |
CN107563430A (en) | A kind of convolutional neural networks algorithm optimization method based on sparse autocoder and gray scale correlation fractal dimension | |
US20230376777A1 (en) | System and method for efficient evolution of deep convolutional neural networks using filter-wise recombination and propagated mutations | |
CN114743027B (en) | Weak supervision learning-guided cooperative significance detection method | |
CN114529793A (en) | Depth image restoration system and method based on gating cycle feature fusion | |
CN111222534A (en) | Single-shot multi-frame detector optimization method based on bidirectional feature fusion and more balanced L1 loss | |
CN111914904A (en) | Image classification method fusing DarkNet and Capsule eNet models | |
CN117151195A (en) | Model optimization method, device, equipment and medium based on inversion normalization | |
CN116129197A (en) | Fish classification method, system, equipment and medium based on reinforcement learning | |
CN116881683A (en) | GA-AM-GRU-based flow industrial energy consumption prediction method | |
CN114937154B (en) | Significance detection method based on recursive decoder | |
CN115375966A (en) | Image countermeasure sample generation method and system based on joint loss function | |
CN113095328A (en) | Self-training-based semantic segmentation method guided by Gini index |
Legal Events
Code | Title | Description |
---|---|---|
PB01 | Publication | |
SE01 | Entry into force of request for substantive examination | |
RJ01 | Rejection of invention patent application after publication | Application publication date: 20230516 |