я пытаюсь обучить модель, я использовал код, который можно найти здесь: https://medium.com/@martin.lees/image-recognition-with-machine-learning-in-python- и-tensorflow-b893cd9014d2
Дело в том, что даже когда я просто копирую/вставляю код, у меня возникает проблема, что я действительно не понимаю, почему он у меня есть. Я много искал на tensorflow Github, но ничего не нашел, чтобы решить мою проблему.
Вот трассировка:
Traceback (most recent call last):
File "D:\pokemon\PogoBot\PoGo-Adb\ml_test_data_test.py", line 108, in <module>
tf.app.run(main=main)
File "C:\Users\pierr\anaconda3\lib\site-packages\tensorflow\python\platform\app.py", line 40, in run
_run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
File "C:\Users\pierr\anaconda3\lib\site-packages\absl\app.py", line 303, in run
_run_main(main, args)
File "C:\Users\pierr\anaconda3\lib\site-packages\absl\app.py", line 251, in _run_main
sys.exit(main(argv))
File "D:\pokemon\PogoBot\PoGo-Adb\ml_test_data_test.py", line 104, in main
saver.save(sess, "./model")
File "C:\Users\pierr\anaconda3\lib\site-packages\tensorflow\python\training\saver.py", line 1183, in save
model_checkpoint_path = sess.run(
File "C:\Users\pierr\anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 957, in run
result = self._run(None, fetches, feed_dict, options_ptr,
File "C:\Users\pierr\anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1180, in _run
results = self._do_run(handle, final_targets, final_fetches,
File "C:\Users\pierr\anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1358, in _do_run
return self._do_call(_run_fn, feeds, fetches, targets, options,
File "C:\Users\pierr\anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1365, in _do_call
return fn(*args)
File "C:\Users\pierr\anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1349, in _run_fn
return self._call_tf_sessionrun(options, feed_dict, fetch_list,
File "C:\Users\pierr\anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1441, in _call_tf_sessionrun
return tf_session.TF_SessionRun_wrapper(self._session, options, feed_dict,
UnicodeDecodeError: 'utf-8' codec can't decode byte 0xe9 in position 109: invalid continuation byte
И вот код:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import cv2
from os import listdir
from os.path import isfile, join
import numpy as np
import tensorflow as tf2
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import math
class Capchat:
data_dir = "data_test//"
nb_categories = 9
X_train = None # X is the data array
Y_train = None # Y is the labels array, you'll see this notation pretty often
train_nb = 0 # number of train images
X_test = None
Y_test = None
test_nb = 0 # number of tests images
index = 0 # the index of the array we will fill
def readimg(self, file, label, train = True):
im = cv2.imread(file); # read the image to PIL image
im = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY).flatten() # put it in black and white and as a vector
# the train var definies if we fill the training dataset or the test dataset
if train :
self.X_train[self.index] = im
self.Y_train[self.index][label - 1] = 1
else :
self.X_test[self.index] = im
self.Y_test[self.index][label - 1] = 1
self.index += 1
def __init__(self):
total_size = [f for f in listdir(self.data_dir + "1/") if isfile(join(self.data_dir + "1/", f))].__len__() # ge the total size of the dataset
self.train_nb = math.floor(total_size * 0.8) # we get 80% of the data to train
self.test_nb = math.ceil(total_size *0.2) # 20% to test
# We fill the arrays with zeroes 840 is the number of pixels in an image
self.X_train = np.zeros((self.train_nb*self.nb_categories, 735), np.int32)
self.Y_train = np.zeros((self.train_nb*self.nb_categories, 3), np.int32)
self.X_test = np.zeros((self.test_nb*self.nb_categories, 735), np.int32)
self.Y_test = np.zeros((self.test_nb*self.nb_categories, 3), np.int32)
# grab all the files
files_1 = [f for f in listdir(self.data_dir+"1/") if isfile(join(self.data_dir+"1/", f))]
files_2 = [f for f in listdir(self.data_dir+"2/") if isfile(join(self.data_dir+"2/", f))]
files_3 = [f for f in listdir(self.data_dir+"3/") if isfile(join(self.data_dir+"3/", f))]
for i in range(self.train_nb):
# add all the files to training dataset
self.readimg(self.data_dir+"1/"+files_1[i], 1)
self.readimg(self.data_dir+"2/"+files_2[i], 2)
self.readimg(self.data_dir+"3/"+files_3[i], 3)
self.index = 0
for i in range (self.train_nb, self.train_nb + self.test_nb):
self.readimg(self.data_dir+"1/" + files_1[i], 1, False)
self.readimg(self.data_dir+"2/" + files_2[i], 2, False)
self.readimg(self.data_dir+"3/" + files_3[i], 3, False)
print("donnée triée")
def main(_):
# Import the data
cap = Capchat()
# Create the model
x = tf.placeholder(tf.float32, [None, 735])
W = tf.Variable(tf.zeros([735, 3]), name = "weights")
b = tf.Variable(tf.zeros([3]), name = "biases")
mult = tf.matmul(x, W) # W * X...
y = tf.add(mult, b, name = "calc") # + b
# Define loss and optimizer
y_ = tf.placeholder(tf.float32, [None, 3])
# cost function
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
# optimizer
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
# allows to save the model later
saver = tf.train.Saver()
# start a session to run the network on
sess = tf.InteractiveSession()
# initialize global variables
tf.global_variables_initializer().run()
# Train for 1000 steps, notice the cap.X_train and cap.Y_train
for _ in range(1000):
sess.run(train_step, feed_dict = {x: cap.X_train, y_: cap.Y_train})
# Extract one hot encoded output via argmax
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
# Test for accuraccy on the testset, notice the cap.X_test and cap.Y_test
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print("\nATTENTION RESULTAT ",sess.run(accuracy, feed_dict = {x: cap.X_test,
y_: cap.Y_test}))
# save the model learned weights and biases
saver.save(sess, "./model")
if __name__ == '__main__':
tf.app.run(main=main)
Ошибка была действительно глупой, потому что я на окнах, эта строка
saver.save(sess, "./model")
была причиной ошибки, поэтому я изменил ее следующим образом:
saver.save(sess, "model\\model")
И теперь это работает.