mirror of
				https://github.com/gmihaila/ml_things.git
				synced 2021-10-04 01:29:04 +03:00 
			
		
		
		
	
		
			
				
	
	
		
			250 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
			
		
		
	
	
			250 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
| {
 | |
|   "nbformat": 4,
 | |
|   "nbformat_minor": 0,
 | |
|   "metadata": {
 | |
|     "colab": {
 | |
|       "name": "keras_embedding.ipynb",
 | |
|       "version": "0.3.2",
 | |
|       "provenance": [],
 | |
|       "collapsed_sections": []
 | |
|     }
 | |
|   },
 | |
|   "cells": [
 | |
|     {
 | |
|       "cell_type": "markdown",
 | |
|       "metadata": {
 | |
|         "id": "view-in-github",
 | |
|         "colab_type": "text"
 | |
|       },
 | |
|       "source": [
 | |
|         "[View in Colaboratory](https://colab.research.google.com/github/gmihaila/deep_learning_toolbox/blob/master/keras_embedding.ipynb)"
 | |
|       ]
 | |
|     },
 | |
|     {
 | |
|       "metadata": {
 | |
|         "id": "wXrf8GWPZKTL",
 | |
|         "colab_type": "text"
 | |
|       },
 | |
|       "cell_type": "markdown",
 | |
|       "source": [
 | |
|         "### Keras Embedding layer as the input layer of a neural network\n",
 | |
|         "\n",
 | |
|         "\n",
 | |
|         "Turns positive integers (indexes) into dense vectors of fixed size. For example, `[[4], [20]] -> [[0.25, 0.1], [0.6, -0.2]]`.\n",
 | |
|         "\n",
 | |
|         "This layer can only be used as the first layer in a model.\n",
 | |
|         "\n",
 | |
|         "Arguments\n",
 | |
|         "\n",
 | |
|         "#### input_dim:\n",
 | |
|         "int > 0. Size of the vocabulary, i.e. maximum integer index + 1.\n",
 | |
|         "\n",
 | |
|         "#### output_dim: \n",
 | |
|         "int >= 0. Dimension of the dense embedding.\n",
 | |
|         "\n",
 | |
|         "#### embeddings_initializer: \n",
 | |
|         "Initializer for the embeddings matrix (see initializers).\n",
 | |
|         "\n",
 | |
|         "#### embeddings_regularizer: \n",
 | |
|         "Regularizer function applied to the embeddings matrix (see regularizer).\n",
 | |
|         "\n",
 | |
|         "#### embeddings_constraint: \n",
 | |
|         "Constraint function applied to the embeddings matrix (see constraints).\n",
 | |
|         "\n",
 | |
|         "#### mask_zero: \n",
 | |
|         "Whether or not the input value 0 is a special \"padding\" value that should be masked out. This is useful when using recurrent layers which may take variable length input. If this is True then all subsequent layers in the model need to support masking or an exception will be raised. If mask_zero is set to True, as a consequence, index 0 cannot be used in the vocabulary (input_dim should equal size of vocabulary + 1).\n",
 | |
|         "\n",
 | |
|         "#### input_length: \n",
 | |
|         "Length of input sequences, when it is constant. This argument is required if you are going to connect `Flatten` then `Dense` layers upstream (without it, the shape of the dense outputs cannot be computed).\n",
 | |
|         "\n",
 | |
|         "#### Input shape\n",
 | |
|         "2D tensor with shape: (batch_size, sequence_length).\n",
 | |
|         "\n",
 | |
|         "#### Output shape\n",
 | |
|         "3D tensor with shape: (batch_size, sequence_length, output_dim)."
 | |
|       ]
 | |
|     },
 | |
|     {
 | |
|       "metadata": {
 | |
|         "id": "YtpuqMG4ZF8u",
 | |
|         "colab_type": "code",
 | |
|         "colab": {
 | |
|           "base_uri": "https://localhost:8080/",
 | |
|           "height": 2125
 | |
|         },
 | |
|         "outputId": "8b14e1cd-912a-432c-dc59-fcb4b06bdeb1"
 | |
|       },
 | |
|       "cell_type": "code",
 | |
|       "source": [
 | |
|         "from keras.layers import Embedding\n",
 | |
|         "from keras.models import Sequential\n",
 | |
|         "\n",
 | |
|         "import numpy as np\n",
 | |
|         "\n",
 | |
|         "# Demo sizes: vocabulary (input_dim), embedding width (output_dim), sequence length (input_length), batch size.\n",
 | |
|         "VOCAB_SIZE = 1000\n",
 | |
|         "EMBED_DIM = 64\n",
 | |
|         "SEQ_LEN = 10\n",
 | |
|         "BATCH_SIZE = 32\n",
 | |
|         "\n",
 | |
|         "model = Sequential()\n",
 | |
|         "model.add(Embedding(VOCAB_SIZE, EMBED_DIM, input_length=SEQ_LEN))\n",
 | |
|         "# the model will take as input an integer matrix of size (batch, input_length).\n",
 | |
|         "# the largest integer (i.e. word index) in the input should be no larger than VOCAB_SIZE - 1 = 999.\n",
 | |
|         "# now model.output_shape == (None, SEQ_LEN, EMBED_DIM), where None is the batch dimension.\n",
 | |
|         "\n",
 | |
|         "# randint's upper bound is exclusive, so every index lies in [0, VOCAB_SIZE - 1].\n",
 | |
|         "input_array = np.random.randint(VOCAB_SIZE, size=(BATCH_SIZE, SEQ_LEN))\n",
 | |
|         "\n",
 | |
|         "# Python 3: print is a function (the old print statements are a SyntaxError).\n",
 | |
|         "print(input_array.shape)\n",
 | |
|         "\n",
 | |
|         "model.compile('rmsprop', 'mse')\n",
 | |
|         "output_array = model.predict(input_array)\n",
 | |
|         "\n",
 | |
|         "assert output_array.shape == (BATCH_SIZE, SEQ_LEN, EMBED_DIM)\n",
 | |
|         "\n",
 | |
|         "print(output_array.shape)\n",
 | |
|         "\n",
 | |
|         "print('INPUT\\n %s' % input_array)\n",
 | |
|         "print('\\n------------------------\\n')\n",
 | |
|         "print('OUTPUT\\n %s' % output_array)"
 | |
|       ],
 | |
|       "execution_count": 8,
 | |
|       "outputs": [
 | |
|         {
 | |
|           "output_type": "stream",
 | |
|           "text": [
 | |
|             "(32, 10)\n",
 | |
|             "(32, 10, 64)\n",
 | |
|             "INPUT\n",
 | |
|             " [[200 528 642  16 944 519 608 432 244 332]\n",
 | |
|             " [600 988 600 305  69 632 937 758 329 931]\n",
 | |
|             " [567 868 282 373 939 376 567 775 280 862]\n",
 | |
|             " [229 315 486 496 280 251 289 971 997 795]\n",
 | |
|             " [879 719 399  54 503 360 128 819 540 678]\n",
 | |
|             " [848  91 247 228 526 379 602 419 541 504]\n",
 | |
|             " [560 249 685 744 313 226 837 375 556 104]\n",
 | |
|             " [122 763 751 930 762   8 258   4 934 701]\n",
 | |
|             " [814 995 169 242 852 735 852  84 520 233]\n",
 | |
|             " [359 985 103 308 878 122 519 151  98 569]\n",
 | |
|             " [865 254   3 825 496 199 318  59 603 828]\n",
 | |
|             " [ 32 314 634 805 257  75 864 320 388 800]\n",
 | |
|             " [464 792 132 649 484  91 479 565 585 250]\n",
 | |
|             " [432 576 203 678 241 794 616 219 555 553]\n",
 | |
|             " [591 752 461 136 894 159 582 284 613 824]\n",
 | |
|             " [657 425 884 698 338 966 481 661 818 197]\n",
 | |
|             " [536 828 881 415 115 602 594 364 104 746]\n",
 | |
|             " [982 993 200 104 576 370 772 860 427 941]\n",
 | |
|             " [638 612 491 858 152 772 540 608 956 237]\n",
 | |
|             " [880  89 599 124 857 325 841  51 411  44]\n",
 | |
|             " [320 937  40 630  71 203 200 204 464 597]\n",
 | |
|             " [800 836 545 175 986 223  15 262 732 851]\n",
 | |
|             " [138 679 482 507  98 178 808 202 414 557]\n",
 | |
|             " [330 376 622 926 747 198  47 887 163 890]\n",
 | |
|             " [477 261 564 433 789 697  73 576 918 646]\n",
 | |
|             " [530 747 869 238 995  30 646 858 406 768]\n",
 | |
|             " [545  55 368 737 717 537 387 306 880 325]\n",
 | |
|             " [547 892 298 677  35 265 950 467 561 337]\n",
 | |
|             " [423 315   5 281 878 923 578 215 692 391]\n",
 | |
|             " [  9 291  53 421 590 386 430 232 656 523]\n",
 | |
|             " [783 178 985  17 831 708 672 317 557 902]\n",
 | |
|             " [466 681 876 141 117 421 131 466 374 433]]\n",
 | |
|             "\n",
 | |
|             "------------------------\n",
 | |
|             "\n",
 | |
|             "OUTPUT\n",
 | |
|             " [[[ 0.00865164  0.01983705  0.04505033 ... -0.00238312 -0.03788055\n",
 | |
|             "   -0.04710914]\n",
 | |
|             "  [-0.01556301  0.00349374  0.01383151 ...  0.04222801  0.00680803\n",
 | |
|             "    0.04339765]\n",
 | |
|             "  [-0.03443142  0.00046226 -0.01667678 ...  0.01966157  0.01743304\n",
 | |
|             "    0.0368357 ]\n",
 | |
|             "  ...\n",
 | |
|             "  [ 0.03649311 -0.02483535 -0.00794562 ...  0.04411651 -0.02312388\n",
 | |
|             "   -0.04150546]\n",
 | |
|             "  [ 0.04872609 -0.04603274  0.00992678 ...  0.03292545  0.03204581\n",
 | |
|             "   -0.01178072]\n",
 | |
|             "  [ 0.01837654  0.01134163 -0.03265047 ... -0.01823749 -0.02920009\n",
 | |
|             "    0.0195968 ]]\n",
 | |
|             "\n",
 | |
|             " [[ 0.03222017  0.03443829 -0.01021215 ...  0.04203905 -0.04446417\n",
 | |
|             "   -0.0212751 ]\n",
 | |
|             "  [-0.03588893 -0.03442007  0.01172714 ... -0.01005517 -0.00669943\n",
 | |
|             "    0.02540362]\n",
 | |
|             "  [ 0.03222017  0.03443829 -0.01021215 ...  0.04203905 -0.04446417\n",
 | |
|             "   -0.0212751 ]\n",
 | |
|             "  ...\n",
 | |
|             "  [-0.04552734  0.04322474  0.03684006 ...  0.01172649 -0.01000365\n",
 | |
|             "    0.03827994]\n",
 | |
|             "  [ 0.04299496  0.02182479 -0.04390707 ...  0.0216657   0.04814878\n",
 | |
|             "    0.02286277]\n",
 | |
|             "  [-0.04043214  0.01640402  0.01287574 ...  0.04241255  0.04999056\n",
 | |
|             "   -0.00672325]]\n",
 | |
|             "\n",
 | |
|             " [[-0.00651056  0.04324975 -0.04504127 ...  0.00540512 -0.00668663\n",
 | |
|             "    0.04224309]\n",
 | |
|             "  [ 0.04159831 -0.01547033 -0.00797663 ... -0.00497701  0.01751376\n",
 | |
|             "    0.01042927]\n",
 | |
|             "  [-0.00550368 -0.03758467  0.01928823 ... -0.02278941  0.04511717\n",
 | |
|             "    0.04343316]\n",
 | |
|             "  ...\n",
 | |
|             "  [-0.00912142 -0.02935383 -0.01909176 ...  0.04066915  0.04970035\n",
 | |
|             "   -0.02182375]\n",
 | |
|             "  [-0.04707226 -0.03681147  0.00422608 ...  0.01812426 -0.04906961\n",
 | |
|             "   -0.01781807]\n",
 | |
|             "  [-0.01119517 -0.02465594 -0.03707733 ... -0.02763356 -0.04833227\n",
 | |
|             "    0.01661377]]\n",
 | |
|             "\n",
 | |
|             " ...\n",
 | |
|             "\n",
 | |
|             " [[-0.0397908  -0.0054639   0.0105643  ... -0.01893784  0.04225798\n",
 | |
|             "    0.00075748]\n",
 | |
|             "  [-0.02829655  0.00231215 -0.01918566 ...  0.03280311 -0.01057085\n",
 | |
|             "   -0.04767976]\n",
 | |
|             "  [-0.03857187  0.04810581 -0.00054383 ... -0.03136816  0.00057728\n",
 | |
|             "   -0.04078513]\n",
 | |
|             "  ...\n",
 | |
|             "  [-0.02183427  0.00416877  0.00832856 ...  0.03447245 -0.0171337\n",
 | |
|             "   -0.03963489]\n",
 | |
|             "  [-0.01198441  0.00114625  0.01870538 ... -0.04129554  0.00198193\n",
 | |
|             "    0.00094149]\n",
 | |
|             "  [-0.01652408  0.01825029  0.02139355 ... -0.03807431  0.02012963\n",
 | |
|             "    0.00419591]]\n",
 | |
|             "\n",
 | |
|             " [[-0.03773869 -0.01919643  0.02687104 ...  0.02074088  0.0057366\n",
 | |
|             "    0.00504166]\n",
 | |
|             "  [ 0.00254613 -0.0104785   0.03174616 ... -0.02018751  0.03865314\n",
 | |
|             "   -0.04680493]\n",
 | |
|             "  [ 0.02643328 -0.03378489  0.00561283 ... -0.00417087 -0.03543005\n",
 | |
|             "   -0.0029667 ]\n",
 | |
|             "  ...\n",
 | |
|             "  [-0.01586244 -0.02336215 -0.02179241 ...  0.02804102  0.01542559\n",
 | |
|             "    0.01742068]\n",
 | |
|             "  [-0.04158032 -0.04901189  0.04360433 ... -0.0171026  -0.00407366\n",
 | |
|             "   -0.00472242]\n",
 | |
|             "  [-0.03892908  0.01685915  0.03663139 ...  0.00143757  0.03873273\n",
 | |
|             "    0.01560397]]\n",
 | |
|             "\n",
 | |
|             " [[ 0.03947726 -0.00067004  0.00936266 ...  0.02809821 -0.01665358\n",
 | |
|             "    0.04616834]\n",
 | |
|             "  [-0.01025372  0.012909   -0.04810945 ... -0.0468235   0.02956841\n",
 | |
|             "   -0.02124548]\n",
 | |
|             "  [ 0.00653504 -0.02408533  0.03635801 ...  0.00175277  0.0384577\n",
 | |
|             "   -0.0279786 ]\n",
 | |
|             "  ...\n",
 | |
|             "  [ 0.03947726 -0.00067004  0.00936266 ...  0.02809821 -0.01665358\n",
 | |
|             "    0.04616834]\n",
 | |
|             "  [-0.01084725 -0.03696607 -0.02323765 ...  0.03600262 -0.03659841\n",
 | |
|             "   -0.00410198]\n",
 | |
|             "  [ 0.0070935  -0.00723008 -0.0198721  ... -0.04326996 -0.01040084\n",
 | |
|             "    0.00187946]]]\n"
 | |
|           ],
 | |
|           "name": "stdout"
 | |
|         }
 | |
|       ]
 | |
|     },
 | |
|     {
 | |
|       "metadata": {
 | |
|         "id": "pWroniz4bFv4",
 | |
|         "colab_type": "text"
 | |
|       },
 | |
|       "cell_type": "markdown",
 | |
|       "source": [
 | |
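|         "Reference: Keras `Embedding` layer documentation, https://keras.io/api/layers/core_layers/embedding/ (originally https://keras.io/layers/embeddings/#embedding)."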
 | |
|       ]
 | |
|     }
 | |
|   ]
 | |
| } | 
