movier/imdb_recommendation.ipynb
2023-08-06 22:48:47 +04:00

518 lines
17 KiB
Plaintext
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn.feature_extraction.text import CountVectorizer\n",
"from sklearn.metrics.pairwise import cosine_similarity\n",
"from ast import literal_eval\n",
"from functools import reduce"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"input_film = 'tt0816692'\n",
"\n",
"trained = {'basics': {}, 'principals': {}, 'ratings': {}}\n",
"\n",
"def drop_by_tconst(df, tconst: str, inplace=True) -> pd.DataFrame:\n",
" return df.drop(df[df['tconst'] == tconst].index, inplace=inplace)\n",
"\n",
"# Basics\n",
"\n",
"df = pd.read_csv('./IMDB_data_sets/filtered/basics.csv', dtype={'tconst': str, 'startYear': 'Int16', 'genres': str})\n",
"df['genres'].fillna('', inplace=True)\n",
"\n",
"cv = CountVectorizer(dtype=np.int8, token_pattern=\"(?u)[\\w'-]+\")\n",
"count_matrix = cv.fit_transform(df['genres'])\n",
"\n",
"trained['basics']['genres'] = pd.DataFrame(\n",
" {\n",
" 'genres': cosine_similarity(count_matrix[df[df['tconst'] == input_film].index[0]], count_matrix)[0],\n",
" 'tconst': df['tconst']\n",
" }\n",
" )\n",
"\n",
"drop_by_tconst(trained['basics']['genres'], input_film)\n",
"\n",
"trained['basics']['genres'].sort_values(ascending=False, by='genres', inplace=True, ignore_index=True)\n",
"trained['basics']['genres'].drop('genres', axis=1, inplace=True)\n",
"\n",
"\n",
"year = int(df[df['tconst'] == input_film].startYear.iloc[0])\n",
"\n",
"trained['basics']['years'] = pd.DataFrame(\n",
" {\n",
" 'years': df['startYear'],\n",
" 'tconst': df['tconst']\n",
" }\n",
")\n",
"\n",
"drop_by_tconst(trained['basics']['years'], input_film)\n",
"trained['basics']['years'].sort_values(by='years', key=lambda x: abs(year-x), inplace=True, ignore_index=True)\n",
"trained['basics']['years'].drop('years', axis=1, inplace=True)\n",
"trained['basics']['years'].reset_index(names='years_index', inplace=True)\n",
"\n",
"# Principals\n",
"\n",
"df = pd.read_csv('./IMDB_data_sets/filtered/principals.csv', dtype={'tconst': str, 'nconst': str}, usecols=['tconst', 'nconst'])\n",
"df.nconst = df.nconst.apply(lambda n: ','.join(literal_eval(n)))\n",
"\n",
"cv = CountVectorizer(dtype=np.int8, token_pattern=\"(?u)[\\w'-]+\")\n",
"count_matrix = cv.fit_transform(df['nconst'])\n",
"\n",
"trained['principals']['nconst'] = pd.DataFrame(\n",
" {\n",
" 'nconst': cosine_similarity(count_matrix[df[df['tconst'] == input_film].index[0]], count_matrix)[0],\n",
" 'tconst': df['tconst']\n",
" }\n",
" )\n",
"\n",
"drop_by_tconst(trained['principals']['nconst'], input_film)\n",
"trained['principals']['nconst'].sort_values(ascending=False, by='nconst', inplace=True, ignore_index=True)\n",
"trained['principals']['nconst'].drop('nconst', axis=1, inplace=True)\n",
"trained['principals']['nconst'].reset_index(names='nconst_index', inplace=True)\n",
"\n",
"# Ratings\n",
"\n",
"df = pd.read_csv('./IMDB_data_sets/filtered/ratings.csv', dtype={'tconst': str, 'averageRating': float, 'numVotes': 'Int64'})\n",
"\n",
"rating = float(df[df['tconst'] == input_film].averageRating.iloc[0])\n",
"votes = int(df[df['tconst'] == input_film].numVotes.iloc[0])\n",
"\n",
"drop_by_tconst(df, input_film)\n",
"\n",
"trained['ratings']['ratings'] = df.sort_values(by='averageRating', key=lambda x: abs(rating-x), ignore_index=True)\n",
"trained['ratings']['ratings'].drop(['averageRating', 'numVotes'], axis=1, inplace=True)\n",
"trained['ratings']['ratings'].reset_index(names='ratings_index', inplace=True)\n",
"\n",
"df.drop('averageRating', axis=1, inplace=True)\n",
"\n",
"trained['ratings']['votes'] = df.sort_values(by='numVotes', key=lambda x: abs(votes-x), ignore_index=True)\n",
"trained['ratings']['votes'].drop('numVotes', axis=1, inplace=True)\n",
"trained['ratings']['votes'].reset_index(names='votes_index', inplace=True)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"merged = reduce(lambda left, right: pd.merge(\n",
" left,\n",
" right,\n",
" on=['tconst'],\n",
" how='outer'\n",
" ), [\n",
" trained['basics']['genres'],\n",
" trained['basics']['years'],\n",
" trained['principals']['nconst'],\n",
" trained['ratings']['ratings'],\n",
" trained['ratings']['votes']\n",
" ])"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>tconst</th>\n",
" <th>years_index</th>\n",
" <th>nconst_index</th>\n",
" <th>ratings_index</th>\n",
" <th>votes_index</th>\n",
" <th>average</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>tt4255564</td>\n",
" <td>744690</td>\n",
" <td>297616</td>\n",
" <td>669670</td>\n",
" <td>669670</td>\n",
" <td>476329.2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>tt2203897</td>\n",
" <td>27293</td>\n",
" <td>9705</td>\n",
" <td>602978</td>\n",
" <td>602978</td>\n",
" <td>248591.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>tt0355627</td>\n",
" <td>344502</td>\n",
" <td>708640</td>\n",
" <td>318038</td>\n",
" <td>205177</td>\n",
" <td>315271.8</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>tt15387378</td>\n",
" <td>710214</td>\n",
" <td>98486</td>\n",
" <td>540358</td>\n",
" <td>540358</td>\n",
" <td>377883.8</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>tt5155340</td>\n",
" <td>72975</td>\n",
" <td>386406</td>\n",
" <td>103102</td>\n",
" <td>152733</td>\n",
" <td>143044.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>777383</th>\n",
" <td>tt1230211</td>\n",
" <td>189871</td>\n",
" <td>599826</td>\n",
" <td>278434</td>\n",
" <td>25539</td>\n",
" <td>374210.6</td>\n",
" </tr>\n",
" <tr>\n",
" <th>777384</th>\n",
" <td>tt12302076</td>\n",
" <td>149946</td>\n",
" <td>599825</td>\n",
" <td>483066</td>\n",
" <td>483066</td>\n",
" <td>498657.4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>777385</th>\n",
" <td>tt1230206</td>\n",
" <td>189885</td>\n",
" <td>599823</td>\n",
" <td>301969</td>\n",
" <td>25847</td>\n",
" <td>378981.8</td>\n",
" </tr>\n",
" <tr>\n",
" <th>777386</th>\n",
" <td>tt1230179</td>\n",
" <td>255769</td>\n",
" <td>599809</td>\n",
" <td>483065</td>\n",
" <td>483065</td>\n",
" <td>519818.8</td>\n",
" </tr>\n",
" <tr>\n",
" <th>777387</th>\n",
" <td>tt9916754</td>\n",
" <td>39373</td>\n",
" <td>777387</td>\n",
" <td>777387</td>\n",
" <td>777387</td>\n",
" <td>629784.2</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>777388 rows × 6 columns</p>\n",
"</div>"
],
"text/plain": [
" tconst years_index nconst_index ratings_index votes_index \n",
"0 tt4255564 744690 297616 669670 669670 \\\n",
"1 tt2203897 27293 9705 602978 602978 \n",
"2 tt0355627 344502 708640 318038 205177 \n",
"3 tt15387378 710214 98486 540358 540358 \n",
"4 tt5155340 72975 386406 103102 152733 \n",
"... ... ... ... ... ... \n",
"777383 tt1230211 189871 599826 278434 25539 \n",
"777384 tt12302076 149946 599825 483066 483066 \n",
"777385 tt1230206 189885 599823 301969 25847 \n",
"777386 tt1230179 255769 599809 483065 483065 \n",
"777387 tt9916754 39373 777387 777387 777387 \n",
"\n",
" average \n",
"0 476329.2 \n",
"1 248591.0 \n",
"2 315271.8 \n",
"3 377883.8 \n",
"4 143044.0 \n",
"... ... \n",
"777383 374210.6 \n",
"777384 498657.4 \n",
"777385 378981.8 \n",
"777386 519818.8 \n",
"777387 629784.2 \n",
"\n",
"[777388 rows x 6 columns]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"merged['average'] = (merged.index*20 + merged.years_index*20 + merged.nconst_index*20 + merged.ratings_index*20 + merged.votes_index*20) / (5*20)\n",
"merged"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>tconst</th>\n",
" <th>years_index</th>\n",
" <th>nconst_index</th>\n",
" <th>ratings_index</th>\n",
" <th>votes_index</th>\n",
" <th>average</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>8695</th>\n",
" <td>tt2338151</td>\n",
" <td>7775</td>\n",
" <td>12586</td>\n",
" <td>23860</td>\n",
" <td>1151</td>\n",
" <td>10813.4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <td>tt3659388</td>\n",
" <td>49654</td>\n",
" <td>98</td>\n",
" <td>25758</td>\n",
" <td>79</td>\n",
" <td>15120.6</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8501</th>\n",
" <td>tt1754656</td>\n",
" <td>30993</td>\n",
" <td>116</td>\n",
" <td>46193</td>\n",
" <td>3247</td>\n",
" <td>17810.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7374</th>\n",
" <td>tt2103281</td>\n",
" <td>11453</td>\n",
" <td>27910</td>\n",
" <td>49618</td>\n",
" <td>347</td>\n",
" <td>19340.4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7549</th>\n",
" <td>tt2358592</td>\n",
" <td>54985</td>\n",
" <td>17633</td>\n",
" <td>12668</td>\n",
" <td>9182</td>\n",
" <td>20403.4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>758545</th>\n",
" <td>tt13334656</td>\n",
" <td>700841</td>\n",
" <td>672174</td>\n",
" <td>500543</td>\n",
" <td>500543</td>\n",
" <td>626529.2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>758577</th>\n",
" <td>tt13336544</td>\n",
" <td>700845</td>\n",
" <td>672184</td>\n",
" <td>500576</td>\n",
" <td>500576</td>\n",
" <td>626551.6</td>\n",
" </tr>\n",
" <tr>\n",
" <th>758587</th>\n",
" <td>tt13335546</td>\n",
" <td>700843</td>\n",
" <td>672231</td>\n",
" <td>500564</td>\n",
" <td>500564</td>\n",
" <td>626557.8</td>\n",
" </tr>\n",
" <tr>\n",
" <th>758590</th>\n",
" <td>tt13335152</td>\n",
" <td>700842</td>\n",
" <td>672247</td>\n",
" <td>500557</td>\n",
" <td>500557</td>\n",
" <td>626558.6</td>\n",
" </tr>\n",
" <tr>\n",
" <th>777387</th>\n",
" <td>tt9916754</td>\n",
" <td>39373</td>\n",
" <td>777387</td>\n",
" <td>777387</td>\n",
" <td>777387</td>\n",
" <td>629784.2</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>777388 rows × 6 columns</p>\n",
"</div>"
],
"text/plain": [
" tconst years_index nconst_index ratings_index votes_index \n",
"8695 tt2338151 7775 12586 23860 1151 \\\n",
"14 tt3659388 49654 98 25758 79 \n",
"8501 tt1754656 30993 116 46193 3247 \n",
"7374 tt2103281 11453 27910 49618 347 \n",
"7549 tt2358592 54985 17633 12668 9182 \n",
"... ... ... ... ... ... \n",
"758545 tt13334656 700841 672174 500543 500543 \n",
"758577 tt13336544 700845 672184 500576 500576 \n",
"758587 tt13335546 700843 672231 500564 500564 \n",
"758590 tt13335152 700842 672247 500557 500557 \n",
"777387 tt9916754 39373 777387 777387 777387 \n",
"\n",
" average \n",
"8695 10813.4 \n",
"14 15120.6 \n",
"8501 17810.0 \n",
"7374 19340.4 \n",
"7549 20403.4 \n",
"... ... \n",
"758545 626529.2 \n",
"758577 626551.6 \n",
"758587 626557.8 \n",
"758590 626558.6 \n",
"777387 629784.2 \n",
"\n",
"[777388 rows x 6 columns]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"merged.sort_values(by='average')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "08dff0a1cb2e37beec5bc340112a669cde11fa0a1a1e2fde92884d26090bd6fc"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}