
Covid-19 Pre-training 11B model on TPU Pod v3-512 and v3-1024 #253

Closed
agemagician opened this issue Jun 2, 2020 · 14 comments

agemagician commented Jun 2, 2020

Hi,

We will start large-scale training on two new unsupervised datasets on TPU Pod v3-512 and v3-1024. This research supports Covid-19 efforts.

We want to make sure that our configuration is correct, and we have several questions:

This is the official gin file for the 11B model that we will use:

# Configure hyperparameters.
gin.parse_config("""
import t5.models.mesh_transformer
import t5.data.sentencepiece_vocabulary
import mesh_tensorflow.optimize
import mesh_tensorflow.transformer.dataset
import mesh_tensorflow.transformer.learning_rate_schedules
import mesh_tensorflow.transformer.t2t_vocabulary
import mesh_tensorflow.transformer.transformer_layers
import mesh_tensorflow.transformer.utils

# Macros:
# ==============================================================================
d_ff = 65536
d_kv = 128
d_model = 1024
dropout_rate = 0.1
inputs_length = 512
mean_noise_span_length = 3.0
MIXTURE_NAME = 'unsupervised_bio_denoising_task'
noise_density = 0.15
num_heads = 128
num_layers = 24
targets_length = 512
#init_checkpoint = "gs://t5-data/pretrained_models/11B/model.ckpt-1000000"
tokens_per_batch = 1048576 

# Parameters for AdafactorOptimizer:
# ==============================================================================
AdafactorOptimizer.beta1 = 0.0
AdafactorOptimizer.clipping_threshold = 1.0
AdafactorOptimizer.decay_rate = None
AdafactorOptimizer.epsilon1 = 1e-30
AdafactorOptimizer.epsilon2 = 0.001
AdafactorOptimizer.factored = True
AdafactorOptimizer.min_dim_size_to_factor = 128
AdafactorOptimizer.multiply_by_parameter_scale = True

# Parameters for Bitransformer:
# ==============================================================================
Bitransformer.shared_embedding = True

# Parameters for denoise:
# ==============================================================================
denoise.inputs_fn = @preprocessors.noise_span_to_unique_sentinel
denoise.noise_density = %noise_density
denoise.noise_mask_fn = @preprocessors.random_spans_noise_mask
denoise.targets_fn = @preprocessors.nonnoise_span_to_unique_sentinel

# Parameters for decoder/DenseReluDense:
# ==============================================================================
decoder/DenseReluDense.dropout_rate = %dropout_rate
decoder/DenseReluDense.hidden_size = %d_ff

# Parameters for encoder/DenseReluDense:
# ==============================================================================
encoder/DenseReluDense.dropout_rate = %dropout_rate
encoder/DenseReluDense.hidden_size = %d_ff

# Parameters for decoder/EncDecAttention:
# ==============================================================================
# None.

# Parameters for get_sentencepiece_model_path:
# ==============================================================================
get_sentencepiece_model_path.mixture_or_task_name = %MIXTURE_NAME

# Parameters for get_variable_dtype:
# ==============================================================================
get_variable_dtype.activation_dtype = 'bfloat16'

# Parameters for decoder/LayerStack:
# ==============================================================================
decoder/LayerStack.dropout_rate = %dropout_rate
decoder/LayerStack.norm_epsilon = 1e-06

# Parameters for encoder/LayerStack:
# ==============================================================================
encoder/LayerStack.dropout_rate = %dropout_rate
encoder/LayerStack.norm_epsilon = 1e-06

# Parameters for learning_rate_schedule_noam:
# ==============================================================================
learning_rate_schedule_noam.linear_decay_fraction = 0.1
learning_rate_schedule_noam.multiplier = 1.0
learning_rate_schedule_noam.offset = 0
learning_rate_schedule_noam.warmup_steps = 40000

# Parameters for make_bitransformer:
# ==============================================================================
make_bitransformer.decoder_name = 'decoder'
make_bitransformer.encoder_name = 'encoder'

# Parameters for decoder/make_layer_stack:
# ==============================================================================
decoder/make_layer_stack.block_scope = True
decoder/make_layer_stack.layers = \
    [@mesh_tensorflow.transformer.transformer_layers.SelfAttention,
     @mesh_tensorflow.transformer.transformer_layers.EncDecAttention,
     @mesh_tensorflow.transformer.transformer_layers.DenseReluDense]
decoder/make_layer_stack.num_layers = %num_layers

# Parameters for encoder/make_layer_stack:
# ==============================================================================
encoder/make_layer_stack.block_scope = True
encoder/make_layer_stack.layers = \
    [@mesh_tensorflow.transformer.transformer_layers.SelfAttention,
     @mesh_tensorflow.transformer.transformer_layers.DenseReluDense]
encoder/make_layer_stack.num_layers = %num_layers

# Parameters for mesh_train_dataset_fn:
# ==============================================================================
mesh_train_dataset_fn.mixture_or_task_name = %MIXTURE_NAME
mesh_train_dataset_fn.use_cached = False

# Parameters for noise_span_to_unique_sentinel:
# ==============================================================================
# None.

# Parameters for nonnoise_span_to_unique_sentinel:
# ==============================================================================
# None.

# Parameters for pack_dataset:
# ==============================================================================
# None

# Parameters for pack_or_pad:
# ==============================================================================
# None.

# Parameters for random_spans_helper:
# ==============================================================================
random_spans_helper.extra_tokens_per_span_inputs = 1
random_spans_helper.extra_tokens_per_span_targets = 1
random_spans_helper.inputs_length = %inputs_length
random_spans_helper.mean_noise_span_length = %mean_noise_span_length
random_spans_helper.noise_density = %noise_density

# Parameters for targets_length/random_spans_helper:
# ==============================================================================
targets_length/random_spans_helper.extra_tokens_per_span_inputs = 1
targets_length/random_spans_helper.extra_tokens_per_span_targets = 1
targets_length/random_spans_helper.inputs_length = %inputs_length
targets_length/random_spans_helper.mean_noise_span_length = %mean_noise_span_length
targets_length/random_spans_helper.noise_density = %noise_density

# Parameters for random_spans_noise_mask:
# ==============================================================================
random_spans_noise_mask.mean_noise_span_length = %mean_noise_span_length

# Parameters for targets_length/random_spans_targets_length:
# ==============================================================================
# None.

# Parameters for random_spans_tokens_length:
# ==============================================================================
# None.

# Parameters for rate_unsupervised:
# ==============================================================================
#rate_unsupervised.value = 133000000.0

# Parameters for reduce_concat_tokens:
# ==============================================================================
reduce_concat_tokens.batch_size = 128
reduce_concat_tokens.feature_key = 'targets'

# Parameters for run:
# ==============================================================================
run.autostack = True
run.batch_size = ('tokens_per_batch', %tokens_per_batch)
run.dataset_split = 'train'
run.eval_checkpoint_step = None
run.eval_dataset_fn = None
run.eval_summary_dir = None
run.export_path = ''
run.iterations_per_loop = 100
run.keep_checkpoint_max = None
run.layout_rules = 'batch:batch,d_ff:model,heads:model,vocab:model'
run.learning_rate_schedule = @learning_rate_schedules.learning_rate_schedule_noam
run.mesh_shape = @mesh_tensorflow.transformer.utils.tpu_mesh_shape()
run.mode = 'train'
run.model_type = 'bitransformer'
run.optimizer = @optimize.AdafactorOptimizer
run.predict_fn = None
run.save_checkpoints_steps = 2400
run.sequence_length = {'inputs': %inputs_length, 'targets': %targets_length}
run.train_dataset_fn = \
    @t5.models.mesh_transformer.mesh_train_dataset_fn
run.train_steps = 1000000
run.variable_filter = None
run.vocabulary = \
    @t5.data.sentencepiece_vocabulary.SentencePieceVocabulary()

# Parameters for select_random_chunk:
# ==============================================================================
select_random_chunk.feature_key = 'targets'
select_random_chunk.max_length = 65536

# Parameters for decoder/SelfAttention:
# ==============================================================================
decoder/SelfAttention.attention_kwargs = None
decoder/SelfAttention.dropout_rate = %dropout_rate
decoder/SelfAttention.key_value_size = %d_kv
decoder/SelfAttention.num_heads = %num_heads
decoder/SelfAttention.num_memory_heads = 0
decoder/SelfAttention.relative_attention_num_buckets = 32
decoder/SelfAttention.relative_attention_type = 'bias_shared'
decoder/SelfAttention.shared_kv = False

# Parameters for encoder/SelfAttention:
# ==============================================================================
encoder/SelfAttention.attention_kwargs = None
encoder/SelfAttention.dropout_rate = %dropout_rate
encoder/SelfAttention.key_value_size = %d_kv
encoder/SelfAttention.num_heads = %num_heads
encoder/SelfAttention.num_memory_heads = 0
encoder/SelfAttention.relative_attention_num_buckets = 32
encoder/SelfAttention.relative_attention_type = 'bias_shared'
encoder/SelfAttention.shared_kv = False

# Parameters for SentencePieceVocabulary:
# ==============================================================================
SentencePieceVocabulary.extra_ids = 100
SentencePieceVocabulary.sentencepiece_model_file = \
    @t5.models.mesh_transformer.get_sentencepiece_model_path()

# Parameters for serialize_num_microbatches:
# ==============================================================================
serialize_num_microbatches.tokens_per_microbatch_per_replica = 16384

# Parameters for split_tokens:
# ==============================================================================
split_tokens.feature_key = 'targets'
split_tokens.max_tokens_per_segment = @preprocessors.random_spans_tokens_length()
split_tokens.min_tokens_per_segment = None

# Parameters for tpu_estimator_model_fn:
# ==============================================================================
#tpu_estimator_model_fn.init_checkpoint = %init_checkpoint
tpu_estimator_model_fn.outer_batch_size = 1
tpu_estimator_model_fn.tpu_summaries = False

# Parameters for tpu_mesh_shape:
# ==============================================================================
tpu_mesh_shape.model_parallelism = 32
tpu_mesh_shape.tpu_topology = '16x32'

# Parameters for decoder/Unitransformer:
# ==============================================================================
decoder/Unitransformer.d_model = %d_model
decoder/Unitransformer.input_full_attention = False
decoder/Unitransformer.label_smoothing = 0.0
decoder/Unitransformer.loss_fn = None
decoder/Unitransformer.loss_on_targets_only = False
decoder/Unitransformer.max_length = 512
decoder/Unitransformer.positional_embedding = False
decoder/Unitransformer.shared_embedding_and_softmax_weights = True
decoder/Unitransformer.vocab_divisor = 128
decoder/Unitransformer.z_loss = 0.0001
decoder/Unitransformer.loss_denominator = 233472

# Parameters for encoder/Unitransformer:
# ==============================================================================
encoder/Unitransformer.d_model = %d_model
encoder/Unitransformer.input_full_attention = False
encoder/Unitransformer.label_smoothing = 0.0
encoder/Unitransformer.loss_fn = None
encoder/Unitransformer.loss_on_targets_only = False
encoder/Unitransformer.max_length = 512
encoder/Unitransformer.positional_embedding = False
encoder/Unitransformer.shared_embedding_and_softmax_weights = True
encoder/Unitransformer.vocab_divisor = 128
encoder/Unitransformer.z_loss = 0.0001

# Parameters for unsupervised:
# ==============================================================================
unsupervised.preprocessors = \
    [@preprocessors.select_random_chunk,
     @preprocessors.reduce_concat_tokens,
     @preprocessors.split_tokens,
     @preprocessors.denoise]


""")

This is the model initialization that we will use:

# The models from our paper are based on the Mesh Tensorflow Transformer.
model = t5.models.MtfModel(
    model_dir=MODEL_DIR,
    tpu=TPU_ADDRESS,
    tpu_topology=TPU_TOPOLOGY,
    model_parallelism=model_parallelism,
    batch_size=train_batch_size,
    sequence_length={"inputs": 512, "targets": 512},
    learning_rate_schedule=0.003,
    save_checkpoints_steps=2400,
    keep_checkpoint_max=keep_checkpoint_max if ON_CLOUD else None,
    iterations_per_loop=100,
    model_type ='bitransformer'
)

This is the training code that we will use:

TRAIN_STEPS = 1000000 #@param {type: "integer"}
model.train(
    mixture_or_task_name="unsupervised_bio_denoising_task",
    steps = TRAIN_STEPS
)

In this gin file we made the following changes:

  • Commented out "init_checkpoint".
  • Changed "learning_rate_schedule_noam.warmup_steps".
  • Commented out "rate_unsupervised.value".
  • Changed "train_steps".
  • Commented out "tpu_estimator_model_fn.init_checkpoint".

My questions are:

  1. Are there any other parameters that we need to change on the gin file?
  2. What are the correct values for "model_parallelism",
    "tpu_topology" and "train_batch_size" for TPU Pod v3-512 and v3-1024?
  3. What is the recommended "learning_rate_schedule" on MtfModel?
  4. Is there anything else we need to change or to adjust?

@adarob @craffel @sharannarang @nshazeer, your feedback is highly appreciated.

adarob (Collaborator) commented Jun 2, 2020

@nshazeer to answer 1 and 3.

For 2:

  • For model_parallelism, 32 should be fine.
  • You should use topology 16x16 for v3-512 and 16x32 for v3-1024.
  • If you want to match our pretraining batch size, you can set batch_size=('tokens_per_batch', 1048576) (see the sketch below).
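
For reference, a minimal sketch of how these values could plug into the MtfModel constructor from the original post (a sketch only, assuming a v3-1024 slice; the topology string would be "16x16" on v3-512, and the other arguments are simply carried over from the post above):

# Sketch only: the recommended pod settings dropped into the MtfModel setup
# from the original post (MODEL_DIR / TPU_ADDRESS are the same placeholders).
import t5

model = t5.models.MtfModel(
    model_dir=MODEL_DIR,
    tpu=TPU_ADDRESS,
    tpu_topology="16x32",                      # "16x16" for a v3-512 slice
    model_parallelism=32,
    batch_size=("tokens_per_batch", 1048576),  # matches the T5 pre-training batch size
    sequence_length={"inputs": 512, "targets": 512},
    learning_rate_schedule=0.003,              # unchanged from the post; see question 3
    save_checkpoints_steps=2400,
    iterations_per_loop=100,
    model_type="bitransformer",
)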

adarob (Collaborator) commented Jun 2, 2020

Note that we typically pre-train without the model API, so you may want to use those instructions (https://github.com/google-research/text-to-text-transfer-transformer#training) instead.

agemagician (Author) commented Jun 2, 2020

@adarob Thanks a lot for your quick reply.

Thanks for clarifying "2", and I will wait for @nshazeer's feedback on "1" and "3".

Regarding your recommendation to switch to "t5_mesh_transformer", I was initially planning to do that, but could you help me convert the model API code to "t5_mesh_transformer" so that it integrates a new task?

This is my code for the new task:

def protein_dataset_fn(split, shuffle_files=True):

  if shuffle_files and isinstance(protein_dataset_path[split], list):
    random.shuffle(protein_dataset_path[split])

  ds = tf.data.TextLineDataset(protein_dataset_path[split]).skip(0)

  # Map each line to a {"targets": ...} dict.
  ds = ds.map(lambda *ex: dict(zip(["targets"], ex)))

  return ds


def text_preprocessor(dataset):

  def _to_inputs_and_targets(input_dict):
    seq_str = input_dict['targets']
    return {"inputs": seq_str, "targets": seq_str}

  return dataset.map(_to_inputs_and_targets, num_parallel_calls=tf.data.experimental.AUTOTUNE)
t5.data.TaskRegistry.remove("unsupervised_bio_denoising_task")
t5.data.TaskRegistry.add(
    "unsupervised_bio_denoising_task",
    # Supply a function which returns a tf.data.Dataset.
    dataset_fn=protein_dataset_fn,
    splits=["train", "validation"],
    # Supply a function which preprocesses text from the tf.data.Dataset.
    text_preprocessor=[text_preprocessor],
    # Supply a function which preprocesses tokens.
    #token_preprocessor=[t5.data.preprocessors.split_tokens,t5.data.preprocessors.denoise],
    token_preprocessor = t5.data.preprocessors.unsupervised,
    # Supply the char-based vocab model.
    output_features=t5.data.Feature(vocabulary=vocab,add_eos=True),
    # We'll use accuracy as our evaluation metric.
    metric_fns=[t5.evaluation.metrics.accuracy],
    # Not required, but helps for mixing and auto-caching.
    num_input_examples=num_examples
)

The dataset consists of several .txt files, where each line contains a single sequence.
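
As a quick sanity check of the raw data, the dataset function can be inspected directly (a sketch, assuming the GCS paths stored in protein_dataset_path are readable from the notebook):

# Sketch: print a couple of raw examples produced by protein_dataset_fn above.
import tensorflow_datasets as tfds  # only used for tfds.as_numpy

ds = protein_dataset_fn("validation", shuffle_files=False)
for ex in tfds.as_numpy(ds.take(2)):
  print(ex)  # e.g. {'targets': b'<one sequence per line>'}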

adarob (Collaborator) commented Jun 2, 2020

One option is to create a local tasks.py file that adds your task to the registry. You can then add a flag like --module_imports=my_module.tasks to t5_mesh_transformer in order to load the task. See https://github.com/google-research/google-research/tree/master/t5_closed_book_qa for an example of this.

agemagician (Author)

Thanks a lot @adarob for your clarification. This is really helpful.
I will follow the example and I will let you know if I have more questions.
I will now wait for @nshazeer's feedback.

agemagician (Author) commented Jun 2, 2020

@adarob I have followed your advice, and I have to say it is much better.

I have tested it on Colab with the small model before running my large-scale training, as follows:

!python -m t5.models.mesh_transformer_main \
  --module_import="bio_task" \
  --tpu="grpc://10.40.9.26:8470" \
  --model_dir="gs://xxxx/models/" \
  --gin_file="/usr/local/lib/python3.6/dist-packages/t5/models/gin/objectives/span.gin" \
  --gin_file="/usr/local/lib/python3.6/dist-packages/t5/models/gin/models/t5.1.0.small.gin" \
  --gin_file="/usr/local/lib/python3.6/dist-packages/t5/models/gin/dataset.gin" \
  --gin_file="/usr/local/lib/python3.6/dist-packages/t5/models/gin/learning_rate_schedules/rsqrt_no_ramp_down.gin" \
  --gin_param="MIXTURE_NAME = 'unsupervised_bio_denoising_task'" \
  --gin_param="utils.tpu_mesh_shape.tpu_topology = '2x2'" \
  --gin_param="utils.run.save_checkpoints_steps=5000" \
  --gin_param="utils.run.batch_size=('tokens_per_batch', 196608)" \
  --gin_param="utils.run.train_steps=500000" \
  --gin_param="SentencePieceVocabulary.extra_ids=100" \
  --gin_param="run.perplexity_eval_steps=100"

My current task file is:

import functools
import random

import t5.data
from t5.data import postprocessors as t5_postprocessors
from t5.evaluation import metrics as t5_metrics
from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary
import tensorflow as tf

TaskRegistry = t5.data.TaskRegistry

trainFilesPath = ['gs://xxxx/001.txt','gs://xxxx/002.txt']
validFilesPath = ['gs://xxxxt/valid.txt']
vocab_model_path = 'bio.model'

vocab = SentencePieceVocabulary(vocab_model_path, extra_ids=100)
print("Vocab has a size of %d\n" % vocab.vocab_size)

protein_dataset_path = {
    "train": trainFilesPath,
    "validation": validFilesPath
}

def protein_dataset_fn(split, shuffle_files=True):

  if shuffle_files and isinstance(protein_dataset_path[split], list):
    random.shuffle(protein_dataset_path[split])

  ds = tf.data.TextLineDataset(protein_dataset_path[split]).skip(0)

  # Map each line to a {"targets": ...} dict.
  ds = ds.map(lambda *ex: dict(zip(["targets"], ex)))

  return ds


def text_preprocessor(dataset):

  def _to_inputs_and_targets(input_dict):
    seq_str = input_dict['targets']
    return {"inputs": seq_str, "targets": seq_str}

  return dataset.map(_to_inputs_and_targets, num_parallel_calls=tf.data.experimental.AUTOTUNE)


t5.data.TaskRegistry.remove("unsupervised_bio_denoising_task")
t5.data.TaskRegistry.add(
    "unsupervised_bio_denoising_task",
    # Supply a function which returns a tf.data.Dataset.
    dataset_fn=protein_dataset_fn,
    splits=["train", "validation"],
    # Supply a function which preprocesses text from the tf.data.Dataset.
    text_preprocessor=[text_preprocessor],
    # Supply a function which preprocesses tokens.
    #token_preprocessor=[token_noise],
    #token_preprocessor=[t5.data.preprocessors.split_tokens,t5.data.preprocessors.denoise],
    token_preprocessor = t5.data.preprocessors.unsupervised,
    # Supply the char-based vocab model.
    output_features=t5.data.Feature(vocabulary=vocab,add_eos=True),
    # We'll use accuracy as our evaluation metric.
    metric_fns=[t5.evaluation.metrics.accuracy]
)
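
Before committing pod time, one way to check the full preprocessing pipeline is to pull a few tokenized examples from the registered task. This is a sketch using the standard t5.data Task.get_dataset API, with the sequence lengths used above:

# Sketch: inspect a few fully preprocessed (tokenized and span-corrupted) examples
# from the registered task before launching the pod job.
import tensorflow_datasets as tfds

task = t5.data.TaskRegistry.get("unsupervised_bio_denoising_task")
ds = task.get_dataset(
    sequence_length={"inputs": 512, "targets": 512},
    split="validation",
)
for ex in tfds.as_numpy(ds.take(2)):
  print("inputs :", vocab.decode(ex["inputs"].tolist()))
  print("targets:", vocab.decode(ex["targets"].tolist()))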

It would be great, @nshazeer, if you could confirm that the above command is correct and that I simply need to make the following changes for the large-scale training with the 11B model:

  1. Replace the small model gin file with the 11b model gin file.
  2. Change the tpu_topology.
  3. Change the tokens_per_batch.

agemagician (Author) commented Jun 6, 2020

@nshazeer Could you please confirm our training scheme here?
We are about to start our large-scale training, and we don't want to waste a lot of TPU compute because of a mistake in the gin files.

craffel (Contributor) commented Jun 7, 2020

Hi Ahmed, your command looks good, though you will need to set utils.tpu_mesh_shape.model_parallelism to be larger than 1 for the larger models. Here is an example command for T5-Base doing C4 unsupervised span-filling training:

t5_mesh_transformer \
  --tpu="..." \
  --gcp_project="..." \
  --tpu_zone="..." \
  --model_dir="..." \
  --gin_file="dataset.gin" \
  --gin_file="models/bi_v1.gin" \
  --gin_param="utils.tpu_mesh_shape.model_parallelism = 1" \
  --gin_param="utils.tpu_mesh_shape.tpu_topology = \"8x8\"" \
  --gin_param="MIXTURE_NAME = \"c4_v020_unsupervised\"" \
  --gin_param="utils.run.train_steps = 1000000" \
  --gin_param="utils.run.save_checkpoints_steps = 5000" \
  --gin_param="utils.run.batch_size = (\"tokens_per_batch\", 1048576)"  \
  --gin_file="learning_rate_schedules/rsqrt_no_ramp_down.gin" \
  --gin_file="objectives/span.gin"

Here is the exact command we used for T5.1.1 XXL C4 unsupervised training:

t5_mesh_transformer \
  --tpu="..." \
  --gcp_project="..." \
  --tpu_zone="..." \
  --model_dir="..." \
  --gin_file="dataset.gin" \
  --gin_file="models/t5.1.1.xxl.gin" \
  --gin_file="objectives/span_3_15_u_u.gin" \
  --gin_file="learning_rate_schedules/rsqrt_no_ramp_down.gin" \
  --gin_param="MIXTURE_NAME = \"c4_v020_unsupervised\"" \
  --gin_param="utils.run.train_steps = 1000000" \
  --gin_param="utils.tpu_mesh_shape.tpu_topology = \"16x32\"" \
  --gin_param="utils.tpu_mesh_shape.model_parallelism = 8" \
  --gin_param="utils.run.batch_size = (\"tokens_per_batch\", 1048576)" \
  --gin_param="serialize_num_microbatches.tokens_per_microbatch_per_replica = 4096"

HTH

agemagician (Author) commented Jun 7, 2020

Thanks a lot @craffel @adarob, I will make sure to add both of you to our paper's acknowledgments :)
Your help is really appreciated.
I recommend adding the training scripts somewhere in the repo for other users.

One last question: in our case the sequence length can go up to 40k. When we trained BERT and ALBERT, we set "max_position_embeddings" to 40k and trained in two phases: the first phase with sequence lengths up to 512 (75% of training) and the second phase up to 2k (25% of training).

In T5 you are using "relative_attention_type", which is set to "bias_shared".

My questions are:

  1. How can we set something similar to max_position_embeddings for relative attention?
    I found "relative_attention_num_buckets", which defaults to 32, but I am not sure this is the right variable to change.
  2. Can we train T5 as we trained BERT, in two phases with two different sequence lengths?

craffel (Contributor) commented Jun 7, 2020

Hi, the relative position bias allows you to use arbitrary sequence lengths. All relative distances above max_distance get put in the same bucket (the "far away" bucket). max_distance defaults to 128. The full logic is here: https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L758
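
For intuition, here is a simplified NumPy sketch of that bucketing (a sketch only, not the Mesh TensorFlow code itself; num_buckets=32 and max_distance=128 are the defaults discussed here):

# Simplified sketch of the relative-position bucketing (bidirectional case).
# Distances at or beyond max_distance all land in the last ("far away") bucket,
# which is why arbitrary sequence lengths still work.
import numpy as np

def relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
  num_buckets //= 2                      # half the buckets for each direction
  ret = (relative_position > 0).astype(np.int32) * num_buckets
  n = np.abs(relative_position)
  max_exact = num_buckets // 2           # small distances get their own exact buckets
  is_small = n < max_exact
  # Larger distances are bucketed logarithmically, saturating at max_distance.
  val_if_large = max_exact + (
      np.log(n.clip(min=1) / max_exact) / np.log(max_distance / max_exact)
      * (num_buckets - max_exact)).astype(np.int32)
  val_if_large = np.minimum(val_if_large, num_buckets - 1)
  return ret + np.where(is_small, n, val_if_large)

print(relative_position_bucket(np.array([1, 16, 64, 128, 5000])))
# -> [17 26 30 31 31]: distances of 128 and 5000 share the "far away" bucket.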

agemagician (Author)

@craffel Thanks for the clarification.

In this case, how can we change "max_distance" with a "gin_param", or do we have to hardcode it in Mesh TensorFlow?

craffel (Contributor) commented Jun 7, 2020

All calls to _relative_position_bucket leave the default in place, so the cleanest thing to do would be to add an argument to the gin-configurable init kwargs of, e.g., SelfAttention (https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L146) that allows specifying max_distance, and pipe it through as is done with, e.g., the relative_attention_num_buckets argument. I think your model will probably work fine with max_distance = 128, though.
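
To illustrate the pattern only (a hypothetical sketch, not existing Mesh TensorFlow code; the class and kwarg names below are made up, and the real edit would go into transformer_layers.py as described above):

# Hypothetical sketch: expose max_distance as a gin-configurable init kwarg and
# pipe it through, mirroring how relative_attention_num_buckets is handled.
import gin

@gin.configurable
class ToySelfAttention:
  def __init__(self,
               relative_attention_num_buckets=32,
               relative_attention_max_distance=128):  # new, made-up kwarg
    self.relative_attention_num_buckets = relative_attention_num_buckets
    self.relative_attention_max_distance = relative_attention_max_distance
    # ...the real layer would pass both values into _relative_position_bucket here.

# Once piped through, the value could be bound like any other gin parameter,
# e.g. --gin_param="SelfAttention.relative_attention_max_distance = 2048".
gin.parse_config("ToySelfAttention.relative_attention_max_distance = 2048")
print(ToySelfAttention().relative_attention_max_distance)  # 2048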

agemagician (Author)

Thanks a lot @craffel for your help and support.

I will follow your advice to adjust max_distance, and I will rethink the default max_distance value.

This closes my issue; thank you again @craffel and @adarob for your help.

craffel (Contributor) commented Jun 7, 2020 via email
